From 2d97b502f0ddf4727b958d880479bfdf5c8aa075 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 1 Aug 2023 17:10:25 +0100 Subject: [PATCH 001/151] rename Parameter to Estimand --- docs/src/index.md | 2 +- docs/src/user_guide.md | 16 ++++----- src/cache.jl | 28 +++++++-------- src/configuration.jl | 18 +++++----- src/estimate.jl | 2 +- src/parameters.jl | 44 +++++++++++------------ src/utils.jl | 2 +- test/cache.jl | 2 +- test/data/parameters.yaml | 2 +- test/data/parameters_with_covariates.yaml | 2 +- test/parameters.jl | 2 +- 11 files changed, 60 insertions(+), 60 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 76effa29..c0d2c8ab 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -70,7 +70,7 @@ IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}[\mathbb{E}[Y|T_1=1, T_2= - \mathbb{E}[\mathbb{E}[Y|T_1=0, T_2=1, W]] + \mathbb{E}[\mathbb{E}[Y|T_1=0, T_2=0, W]] ``` -### Any function of the previous Parameters +### Any function of the previous Estimands As a result of Julia's automatic differentiation facilities, given a set of already estimated parameters $(\Psi_1, ..., \Psi_k)$, we can automatically compute an estimator for $f(\Psi_1, ..., \Psi_k)$. diff --git a/docs/src/user_guide.md b/docs/src/user_guide.md index f14cfa17..8821fdb2 100644 --- a/docs/src/user_guide.md +++ b/docs/src/user_guide.md @@ -6,7 +6,7 @@ CurrentModule = TMLE ## The Dataset -TMLE.jl should be compatible with any dataset respecting the [Tables.jl](https://tables.juliadata.org/stable/) interface, that is, a structure like a `NamedTuple` or a `DataFrame` from [DataFrames.jl](https://dataframes.juliadata.org/stable/) should work. In the remainder of this section, we will be working with the same dataset and see that we can ask very many questions (Parameters) from it. +TMLE.jl should be compatible with any dataset respecting the [Tables.jl](https://tables.juliadata.org/stable/) interface, that is, a structure like a `NamedTuple` or a `DataFrame` from [DataFrames.jl](https://dataframes.juliadata.org/stable/) should work. In the remainder of this section, we will be working with the same dataset and see that we can ask very many questions (Estimands) from it. ```@example user-guide using Random @@ -43,9 +43,9 @@ nothing # hide ``` !!! note "Note on Treatment variables" - It should be noted that the treatment variables **must** be [categorical](https://categoricalarrays.juliadata.org/stable/). Since the treatment is also used as an input to the ``Q_0`` learner, a `OneHotEncoder` is used by default (see [The Nuisance Parameters](@ref) section). If a numerical representation is more appropriate (ordinal variables), use the keyword `ordered=true` when constructing a `categorical` vector, the `OneHotEncoder` will ignore those variables and their floating point representation will be used. + It should be noted that the treatment variables **must** be [categorical](https://categoricalarrays.juliadata.org/stable/). Since the treatment is also used as an input to the ``Q_0`` learner, a `OneHotEncoder` is used by default (see [The Nuisance Estimands](@ref) section). If a numerical representation is more appropriate (ordinal variables), use the keyword `ordered=true` when constructing a `categorical` vector, the `OneHotEncoder` will ignore those variables and their floating point representation will be used. -## The Nuisance Parameters +## The Nuisance Estimands As described in the [Mathematical setting](@ref) section, we need to provide an estimation strategy for both $Q_0$ and $G_0$. For illustration purposes, we here consider a simple strategy where both models are assumed to be generalized linear models. However this is not the recommended practice since there is little chance those functions are actually linear, and theoretical guarantees associated with TMLE may fail to hold. We recommend instead the use of Super Learning which is exemplified in [The benefits of Super Learning](@ref). @@ -67,7 +67,7 @@ nothing # hide The `NuisanceSpec` struct also holds a specification for the `OneHotEncoder` necessary for the encoding of treatment variables and a generalized linear model for the fluctuation model. Unless you know what you are doing, there is little chance you need to modify those. -## Parameters +## Estimands ### The Conditional mean @@ -181,7 +181,7 @@ iate_result, _ = tmle(Ψ, η_spec, dataset, verbosity=0) OneSampleTTest(iate_result.tmle) ``` -### Composing Parameters +### Composing Estimands By leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation facilities, we can actually compute any new parameter estimate from a set of already estimated parameters. By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system. @@ -214,12 +214,12 @@ and the variance: var(composed_ate_result), var(ate_result.tmle) ``` -### Reading Parameters from YAML files +### Reading Estimands from YAML files It may be useful to read and write parameters to text files. We provide this functionality using the YAML format and the `parameters_from_yaml` and `parameters_to_yaml` functions. Since YAML is quite sensitive to identation, it is not recommended to write those files by hand. As an illustration, an example of such a file is given below: ```yaml -Parameters: +Estimands: - target: Y1 treatment: (T2 = (case = 1, control = 0), T1 = (case = 1, control = 0)) confounders: @@ -320,5 +320,5 @@ Since we have only updated $G$'s specification, only this model is fitted again. ### General behaviour -Any change to either the `Parameter` of interest or the `NuisanceSpec` structures will trigger an update of the cache. If you have a predefined +Any change to either the `Estimand` of interest or the `NuisanceSpec` structures will trigger an update of the cache. If you have a predefined list of parameters to estimate, you can optimize the estimation order of those parameters by calling `optimize_ordering!` or `optimize_ordering`. This will order the parameters in order to save the number of required nuisance parameters fits. diff --git a/src/cache.jl b/src/cache.jl index 5c3042f5..abf69750 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -10,12 +10,12 @@ saving computations. """ mutable struct TMLECache data - η::NuisanceParameters - Ψ::Parameter + η::NuisanceEstimands + Ψ::Estimand η_spec::NuisanceSpec function TMLECache(dataset) data = Dict{Symbol, Any}(:source => dataset) - η = NuisanceParameters(nothing, nothing, nothing, nothing) + η = NuisanceEstimands(nothing, nothing, nothing, nothing) new(data, η) end end @@ -39,7 +39,7 @@ function check_treatment_settings(setting, levels, treatment_name) " in Ψ does not match any level of the corresponding variable in the dataset: ", string.(levels)))) end -function check_treatment_values(cache::TMLECache, Ψ::Parameter) +function check_treatment_values(cache::TMLECache, Ψ::Estimand) for treatment_name in treatments(Ψ) treatment_levels = string.(levels(Tables.getcolumn(cache.data[:source], treatment_name))) treatment_settings = getproperty(Ψ.treatment, treatment_name) @@ -48,7 +48,7 @@ function check_treatment_values(cache::TMLECache, Ψ::Parameter) end -function update!(cache::TMLECache, Ψ::Parameter) +function update!(cache::TMLECache, Ψ::Estimand) check_treatment_values(cache, Ψ) any_variable_changed = false if !isdefined(cache, :Ψ) @@ -102,7 +102,7 @@ function update!(cache::TMLECache, η_spec::NuisanceSpec) return cache end -function update!(cache::TMLECache, Ψ::Parameter, η_spec::NuisanceSpec) +function update!(cache::TMLECache, Ψ::Estimand, η_spec::NuisanceSpec) update!(cache, Ψ) update!(cache, η_spec) end @@ -119,13 +119,13 @@ const TMLE_ARGS_DOCS = """ """ """ - tmle(Ψ::Parameter, η_spec::NuisanceSpec, dataset; kwargs...) + tmle(Ψ::Estimand, η_spec::NuisanceSpec, dataset; kwargs...) Build a TMLECache struct and runs TMLE. $TMLE_ARGS_DOCS """ -function tmle(Ψ::Parameter, η_spec::NuisanceSpec, dataset; kwargs...) +function tmle(Ψ::Estimand, η_spec::NuisanceSpec, dataset; kwargs...) cache = TMLECache(dataset) update!(cache, Ψ, η_spec) return tmle!(cache; kwargs...) @@ -157,13 +157,13 @@ function tmle!(cache::TMLECache; verbosity=1, threshold=1e-8, weighted_fluctuati end """ - tmle!(cache::TMLECache, Ψ::Parameter; verbosity=1, threshold=1e-8) + tmle!(cache::TMLECache, Ψ::Estimand; verbosity=1, threshold=1e-8) Updates the TMLECache with Ψ and runs TMLE. $TMLE_ARGS_DOCS """ -function tmle!(cache::TMLECache, Ψ::Parameter; kwargs...) +function tmle!(cache::TMLECache, Ψ::Estimand; kwargs...) update!(cache, Ψ) tmle!(cache; kwargs...) end @@ -181,25 +181,25 @@ function tmle!(cache::TMLECache, η_spec::NuisanceSpec; kwargs...) end """ - tmle!(cache::TMLECache, Ψ::Parameter, η_spec::NuisanceSpec; kwargs...) + tmle!(cache::TMLECache, Ψ::Estimand, η_spec::NuisanceSpec; kwargs...) Updates the TMLECache with η_spec and Ψ and runs TMLE. $TMLE_ARGS_DOCS """ -function tmle!(cache::TMLECache, Ψ::Parameter, η_spec::NuisanceSpec; kwargs...) +function tmle!(cache::TMLECache, Ψ::Estimand, η_spec::NuisanceSpec; kwargs...) update!(cache, Ψ, η_spec) tmle!(cache; kwargs...) end """ - tmle!(cache::TMLECache, η_spec::NuisanceSpec, Ψ::Parameter; kwargs...) + tmle!(cache::TMLECache, η_spec::NuisanceSpec, Ψ::Estimand; kwargs...) Updates the TMLECache with η_spec and Ψ and runs TMLE. $TMLE_ARGS_DOCS """ -function tmle!(cache::TMLECache, η_spec::NuisanceSpec, Ψ::Parameter; kwargs...) +function tmle!(cache::TMLECache, η_spec::NuisanceSpec, Ψ::Estimand; kwargs...) update!(cache, Ψ, η_spec) tmle!(cache; kwargs...) end diff --git a/src/configuration.jl b/src/configuration.jl index daa8b458..2f3b95b6 100644 --- a/src/configuration.jl +++ b/src/configuration.jl @@ -1,7 +1,7 @@ -Configurations.from_dict(::Type{<:Parameter}, ::Type{Symbol}, s) = Symbol(s) -Configurations.from_dict(::Type{<:Parameter}, ::Type{NamedTuple}, s) = eval(Meta.parse(s)) +Configurations.from_dict(::Type{<:Estimand}, ::Type{Symbol}, s) = Symbol(s) +Configurations.from_dict(::Type{<:Estimand}, ::Type{NamedTuple}, s) = eval(Meta.parse(s)) -function param_to_dict(Ψ::Parameter) +function param_to_dict(Ψ::Estimand) d = to_dict(Ψ, YAMLStyle) d["type"] = typeof(Ψ) return d @@ -10,13 +10,13 @@ end """ parameters_from_yaml(filepath) -Read estimation `Parameters` from YAML file. +Read estimation `Estimands` from YAML file. """ function parameters_from_yaml(filepath) yaml_dict = YAML.load_file(filepath; dicttype=Dict{String, Any}) - params_dicts = yaml_dict["Parameters"] + params_dicts = yaml_dict["Estimands"] nparams = size(params_dicts, 1) - parameters = Vector{Parameter}(undef, nparams) + parameters = Vector{Estimand}(undef, nparams) for index in 1:nparams d = params_dicts[index] type = pop!(d, "type") @@ -26,11 +26,11 @@ function parameters_from_yaml(filepath) end """ - parameters_to_yaml(filepath, parameters::Vector{<:Parameter}) + parameters_to_yaml(filepath, parameters::Vector{<:Estimand}) -Write estimation `Parameters` to YAML file. +Write estimation `Estimands` to YAML file. """ function parameters_to_yaml(filepath, parameters) - d = Dict("Parameters" => [param_to_dict(Ψ) for Ψ in parameters]) + d = Dict("Estimands" => [param_to_dict(Ψ) for Ψ in parameters]) YAML.write_file(filepath, d) end \ No newline at end of file diff --git a/src/estimate.jl b/src/estimate.jl index e3008cb6..b3516726 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -11,7 +11,7 @@ struct ComposedTMLE{T<:AbstractFloat} <: AbstractTMLE σ̂::Matrix{T} end -struct TMLEResult{P <: Parameter, T<:AbstractFloat} +struct TMLEResult{P <: Estimand, T<:AbstractFloat} parameter::P tmle::ALEstimate{T} onestep::ALEstimate{T} diff --git a/src/parameters.jl b/src/parameters.jl index 7eabc737..9b878760 100644 --- a/src/parameters.jl +++ b/src/parameters.jl @@ -13,12 +13,12 @@ const causal_graph = """ """ ##################################################################### -### Abstract Parameter ### +### Abstract Estimand ### ##################################################################### """ -A Parameter is a functional on distribution space Ψ: ℳ → ℜ. +A Estimand is a functional on distribution space Ψ: ℳ → ℜ. """ -abstract type Parameter end +abstract type Estimand end ##################################################################### ### Conditional Mean ### @@ -57,7 +57,7 @@ CM₂ = CM( ) ``` """ -@option struct CM <: Parameter +@option struct CM <: Estimand target::Symbol treatment::NamedTuple confounders::Vector{Symbol} @@ -101,7 +101,7 @@ ATE₂ = ATE( ) ``` """ -@option struct ATE <: Parameter +@option struct ATE <: Estimand target::Symbol treatment::NamedTuple confounders::Vector{Symbol} @@ -139,7 +139,7 @@ IATE₁ = IATE( ) ``` """ -@option struct IATE <: Parameter +@option struct IATE <: Estimand target::Symbol treatment::NamedTuple confounders::Vector{Symbol} @@ -148,11 +148,11 @@ end ##################################################################### -### Nuisance Parameters ### +### Nuisance Estimands ### ##################################################################### """ -# NuisanceParameters +# NuisanceEstimands The set of estimators that need to be estimated but are not of direct interest. @@ -169,7 +169,7 @@ All fields are MLJBase.Machine. - H: A one-hot-encoder categorical treatments - F: A generalized linear model to fluctuate E[Y|X] """ -mutable struct NuisanceParameters +mutable struct NuisanceEstimands Q::Union{Nothing, MLJBase.Machine} G::Union{Nothing, MLJBase.Machine} H::Union{Nothing, MLJBase.Machine} @@ -206,33 +206,33 @@ NuisanceSpec(Q, G; H=TreatmentTransformer(), F=F_model(target_scitype(Q)), cache selectcols(data, cols) = data |> TableOperations.select(cols...) |> Tables.columntable -confounders(Ψ::Parameter) = Ψ.confounders +confounders(Ψ::Estimand) = Ψ.confounders confounders(dataset, Ψ) = selectcols(dataset, confounders(Ψ)) -covariates(Ψ::Parameter) = Ψ.covariates +covariates(Ψ::Estimand) = Ψ.covariates covariates(dataset, Ψ) = selectcols(dataset, covariates(Ψ)) -treatments(Ψ::Parameter) = collect(keys(Ψ.treatment)) +treatments(Ψ::Estimand) = collect(keys(Ψ.treatment)) treatments(dataset, Ψ) = selectcols(dataset, treatments(Ψ)) -target(Ψ::Parameter) = Ψ.target +target(Ψ::Estimand) = Ψ.target target(dataset, Ψ) = Tables.getcolumn(dataset, target(Ψ)) -treatment_and_confounders(Ψ::Parameter) = vcat(confounders(Ψ), treatments(Ψ)) +treatment_and_confounders(Ψ::Estimand) = vcat(confounders(Ψ), treatments(Ψ)) -confounders_and_covariates(Ψ::Parameter) = vcat(confounders(Ψ), covariates(Ψ)) +confounders_and_covariates(Ψ::Estimand) = vcat(confounders(Ψ), covariates(Ψ)) confounders_and_covariates(dataset, Ψ) = selectcols(dataset, confounders_and_covariates(Ψ)) """ Merges together confounders, covariates and floating point representation of treatments. """ -Qinputs(H, dataset, Ψ::Parameter) = merge( +Qinputs(H, dataset, Ψ::Estimand) = merge( columntable(confounders_and_covariates(dataset, Ψ)), MLJBase.transform(H, treatments(dataset, Ψ)) ) -allcolumns(Ψ::Parameter) = vcat(confounders_and_covariates(Ψ), treatments(Ψ), target(Ψ)) +allcolumns(Ψ::Estimand) = vcat(confounders_and_covariates(Ψ), treatments(Ψ), target(Ψ)) F_model(::Type{<:AbstractVector{<:MLJBase.Continuous}}) = LinearRegressor(fit_intercept=false, offsetcol = :offset) @@ -247,7 +247,7 @@ namedtuples_from_dicts(d::Dict) = NamedTuple{Tuple(keys(d))}([namedtuples_from_dicts(val) for val in values(d)]) -function param_key(Ψ::Parameter) +function param_key(Ψ::Estimand) return ( join(Ψ.confounders, "_"), join(keys(Ψ.treatment), "_"), @@ -257,7 +257,7 @@ function param_key(Ψ::Parameter) end """ - optimize_ordering!(parameters::Vector{<:Parameter}) + optimize_ordering!(parameters::Vector{<:Estimand}) Reorders the given parameters so that most nuisance parameters fits can be reused. Given the assumed causal graph: @@ -269,11 +269,11 @@ and the requirements to estimate both p(T|W) and E[Y|W, T, C]. A natural ordering of the parameters in order to save computations is given by the following variables ordering: (W, T, Y, C) """ -optimize_ordering!(parameters::Vector{<:Parameter}) = sort!(parameters, by=param_key) +optimize_ordering!(parameters::Vector{<:Estimand}) = sort!(parameters, by=param_key) """ - optimize_ordering(parameters::Vector{<:Parameter}) + optimize_ordering(parameters::Vector{<:Estimand}) See [`optimize_ordering!`](@ref) """ -optimize_ordering(parameters::Vector{<:Parameter}) = sort(parameters, by=param_key) \ No newline at end of file +optimize_ordering(parameters::Vector{<:Estimand}) = sort(parameters, by=param_key) \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 70d79408..8dd85a54 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -44,7 +44,7 @@ function nomissing(table, columns) return nomissing(columns) end -ncases(value, Ψ::Parameter) = sum(value[i] == Ψ.treatment[i].case for i in eachindex(value)) +ncases(value, Ψ::Estimand) = sum(value[i] == Ψ.treatment[i].case for i in eachindex(value)) function indicator_fns(Ψ::IATE, f::Function) N = length(treatments(Ψ)) diff --git a/test/cache.jl b/test/cache.jl index adebbbe7..44c9f1e7 100644 --- a/test/cache.jl +++ b/test/cache.jl @@ -27,7 +27,7 @@ function fakefill!(cache) X = (x=rand(n),) y = rand(n) mach = machine(LinearRegressor(), X, y) - cache.η = TMLE.NuisanceParameters(mach, mach, mach, mach) + cache.η = TMLE.NuisanceEstimands(mach, mach, mach, mach) end function fakecache() diff --git a/test/data/parameters.yaml b/test/data/parameters.yaml index 6f8f7939..9a1fd785 100644 --- a/test/data/parameters.yaml +++ b/test/data/parameters.yaml @@ -9,7 +9,7 @@ Y: - Y1 - Y2 -Parameters: +Estimands: - name: IATE T1: case: 1 diff --git a/test/data/parameters_with_covariates.yaml b/test/data/parameters_with_covariates.yaml index a8678d93..6d678d2b 100644 --- a/test/data/parameters_with_covariates.yaml +++ b/test/data/parameters_with_covariates.yaml @@ -10,7 +10,7 @@ W: Y: - Y1 -Parameters: +Estimands: - name: IATE T1: case: 2 diff --git a/test/parameters.jl b/test/parameters.jl index 908e143e..4df32712 100644 --- a/test/parameters.jl +++ b/test/parameters.jl @@ -1,4 +1,4 @@ -module TestParameters +module TestEstimands using Test using TMLE From d1186495b055d0ae96d07ff03f9724615be123f4 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 1 Aug 2023 17:15:15 +0100 Subject: [PATCH 002/151] rename parameter to estimand --- README.md | 2 +- docs/src/index.md | 14 +-- docs/src/user_guide.md | 28 ++--- examples/introduction_to_targeted_learning.jl | 12 +- examples/super_learning.jl | 4 +- src/TMLE.jl | 12 +- src/cache.jl | 18 +-- src/configuration.jl | 16 +-- src/{parameters.jl => estimands.jl} | 14 +-- src/estimate.jl | 2 +- test/cache.jl | 6 +- test/configuration.jl | 16 +-- test/{parameters.jl => estimands.jl} | 10 +- test/runtests.jl | 2 +- test/utils.jl | 4 +- test/warm_restart.jl | 104 +++++++++--------- 16 files changed, 132 insertions(+), 132 deletions(-) rename src/{parameters.jl => estimands.jl} (93%) rename test/{parameters.jl => estimands.jl} (94%) diff --git a/README.md b/README.md index 507dce0e..4cd969be 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ ![Codecov](https://img.shields.io/codecov/c/github/TARGENE/TMLE.jl/main) ![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/TARGENE/TMLE.jl) -Targeted Minimum Loss-Based Estimation (TMLE) is a framework for efficient estimation of pathwise differentiable parameters. The goal of this package is to provide an implementation on top of the [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) framework. +Targeted Minimum Loss-Based Estimation (TMLE) is a framework for efficient estimation of pathwise differentiable estimands. The goal of this package is to provide an implementation on top of the [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) framework. **New to TMLE.jl?** The API design is still work in progress but you can already [get started now](https://targene.github.io/TMLE.jl/stable/) diff --git a/docs/src/index.md b/docs/src/index.md index c0d2c8ab..4ee9c188 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -6,7 +6,7 @@ CurrentModule = TMLE ## Overview -TMLE.jl is a Julia implementation of the Targeted Minimum Loss-Based Estimation ([TMLE](https://link.springer.com/book/10.1007/978-1-4419-9782-1)) framework. If you are interested in efficient and unbiased estimation of some causal effect, you are in the right place. Since TMLE uses machine-learning methods to estimate nuisance parameters, the present package is based upon [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/). This means that any model respecting the MLJ interface can be used for the estimation of nuisance parameters. +TMLE.jl is a Julia implementation of the Targeted Minimum Loss-Based Estimation ([TMLE](https://link.springer.com/book/10.1007/978-1-4419-9782-1)) framework. If you are interested in efficient and unbiased estimation of some causal effect, you are in the right place. Since TMLE uses machine-learning methods to estimate nuisance estimands, the present package is based upon [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/). This means that any model respecting the MLJ interface can be used for the estimation of nuisance estimands. ## Mathematical setting @@ -31,7 +31,7 @@ This graph encodes a factorization of the joint probability distribution $P_0$ t P_0(Y, T, W, C) = P_0(Y|T, W, C) \cdot P_0(T|W) \cdot P_0(W) \cdot P_0(C) ``` -Usually, we don't have much information about $P_0$ and don't want to make unrealistic assumptions, thus we will simply state that $P_0 \in \mathcal{M}$ where $\mathcal{M}$ is the space of all probability distributions. It turns out that most target quantities of interest in causal inference can be expressed as real-valued functions of $P_0$, denoted by: $\Psi:\mathcal{M} \rightarrow \Re$. TMLE works by finding a suitable estimator $\hat{P}_n$ of $P_0$ and then simply substituting it into $\Psi$. Fortunately, it is often non necessary to come up with an estimator of the joint distribution $P_0$, instead only subparts of it are required. Those are called nuisance parameters because they are not of direct interest and for our purposes, only two nuisance parameters are necessary: +Usually, we don't have much information about $P_0$ and don't want to make unrealistic assumptions, thus we will simply state that $P_0 \in \mathcal{M}$ where $\mathcal{M}$ is the space of all probability distributions. It turns out that most target quantities of interest in causal inference can be expressed as real-valued functions of $P_0$, denoted by: $\Psi:\mathcal{M} \rightarrow \Re$. TMLE works by finding a suitable estimator $\hat{P}_n$ of $P_0$ and then simply substituting it into $\Psi$. Fortunately, it is often non necessary to come up with an estimator of the joint distribution $P_0$, instead only subparts of it are required. Those are called nuisance estimands because they are not of direct interest and for our purposes, only two nuisance estimands are necessary: ```math Q_0(t, w, c) = \mathbb{E}[Y|T=t, W=w, C=c] @@ -43,7 +43,7 @@ and G_0(t, w) = P(T=t|W=w) ``` -Specifying which machine-learning method to use for each of those nuisance parameter is the only ingredient you will need to run a targeted estimation for any of the following quantities. Those quantities have a causal interpretation under suitable assumptions that we do not discuss here. +Specifying which machine-learning method to use for each of those nuisance estimand is the only ingredient you will need to run a targeted estimation for any of the following quantities. Those quantities have a causal interpretation under suitable assumptions that we do not discuss here. ### The Conditional Mean @@ -72,7 +72,7 @@ IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}[\mathbb{E}[Y|T_1=1, T_2= ### Any function of the previous Estimands -As a result of Julia's automatic differentiation facilities, given a set of already estimated parameters $(\Psi_1, ..., \Psi_k)$, we can automatically compute an estimator for $f(\Psi_1, ..., \Psi_k)$. +As a result of Julia's automatic differentiation facilities, given a set of already estimated estimands $(\Psi_1, ..., \Psi_k)$, we can automatically compute an estimator for $f(\Psi_1, ..., \Psi_k)$. ## Installation @@ -87,8 +87,8 @@ Pkg> add TMLE To run an estimation procedure, we need 3 ingredients: - A dataset -- A parameter of interest -- A nuisance parameters specification +- A estimand of interest +- A nuisance estimands specification For instance, assume the following simple data generating process: @@ -137,7 +137,7 @@ Note that in this example the ATE can be computed exactly and is given by: ATE_{0 \rightarrow 1}(P_0) = \mathbb{E}[1 + 3 - W] - \mathbb{E}[1] = 3 - \mathbb{E}[W] = 2.5 ``` -We next need to define the strategy for learning the nuisance parameters $Q$ and $G$, here we keep things simple and simply use generalized linear models: +We next need to define the strategy for learning the nuisance estimands $Q$ and $G$, here we keep things simple and simply use generalized linear models: ```@example quick-start η_spec = NuisanceSpec( diff --git a/docs/src/user_guide.md b/docs/src/user_guide.md index 8821fdb2..44ba6ed7 100644 --- a/docs/src/user_guide.md +++ b/docs/src/user_guide.md @@ -60,7 +60,7 @@ nothing # hide ``` !!! note "Practical note on $Q_0$ and $G_0$" - The models chosen for the nuisance parameters should be adapted to the outcome they target: + The models chosen for the nuisance estimands should be adapted to the outcome they target: - ``Q_0 = \mathbf{E}_0[Y|T=t, W=w, C=c]``, $Y$ can be either continuous or categorical, in our example it is continuous and a `LinearRegressor` is a correct choice. - ``G_0 = P_0(T|W)``, $T$ are always categorical variables. If T is a single treatment with only 2 levels, a logistic regression will work, if T has more than two levels a multinomial regression for instance would be suitable. If there are more than 2 treatment variables (with potentially more than 2 levels), then, the joint distribution is learnt and a multinomial regression would also work. In any case, the `LogisticClassifier` from [MLJLinearModels](https://juliaai.github.io/MLJLinearModels.jl/stable/) is a suitable choice. For more information on available models and their uses, we refer to the [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) documentation @@ -71,7 +71,7 @@ The `NuisanceSpec` struct also holds a specification for the `OneHotEncoder` nec ### The Conditional mean -We are now ready to move to the definition of the parameters of interest. The most basic type of parameter is the conditional mean of the target given the treatment: +We are now ready to move to the definition of the estimands of interest. The most basic type of estimand is the conditional mean of the target given the treatment: ```math CM_t(P) = \mathbb{E}[\mathbb{E}[Y|T=t, W]] @@ -88,7 +88,7 @@ The treatment does not have to be restricted to a single variable, we can define nothing # hide ``` -In this case, we can compute the exact value of the parameter: +In this case, we can compute the exact value of the estimand: ```math CM_{T_1=1, T_2=1} = 1 + 2\mathbb{E}[W₁] + 3\mathbb{E}[W₂] - 4\mathbb{E}[C₁] - 2\mathbb{E}[W₂] = 0.5 @@ -183,7 +183,7 @@ OneSampleTTest(iate_result.tmle) ### Composing Estimands -By leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation facilities, we can actually compute any new parameter estimate from a set of already estimated parameters. By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system. +By leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation facilities, we can actually compute any new estimand estimate from a set of already estimated estimands. By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system. For instance, by definition of the ATE, we should be able to retrieve ``ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1}`` by composing ``CM_{T_1=1, T_2=1} - CM_{T_1=0, T_2=0}``. We already have almost all of the pieces, we just need an estimate for ``CM_{T_1=0, T_2=0}``, let's get it. @@ -216,7 +216,7 @@ var(composed_ate_result), var(ate_result.tmle) ### Reading Estimands from YAML files -It may be useful to read and write parameters to text files. We provide this functionality using the YAML format and the `parameters_from_yaml` and `parameters_to_yaml` functions. Since YAML is quite sensitive to identation, it is not recommended to write those files by hand. As an illustration, an example of such a file is given below: +It may be useful to read and write estimands to text files. We provide this functionality using the YAML format and the `estimands_from_yaml` and `estimands_to_yaml` functions. Since YAML is quite sensitive to identation, it is not recommended to write those files by hand. As an illustration, an example of such a file is given below: ```yaml Estimands: @@ -244,7 +244,7 @@ Estimands: ## Using the cache -Oftentimes, we are interested in multiple parameters, or would like to investigate how our estimator is affected by changes in the nuisance parameters specification. In many cases, as long as the dataset under study is the same, it is possible to save some computational time by caching the previously learnt nuisance parameters. We describe below how TMLE.jl proposes to do that in some common scenarios. For that purpose let us add a new target variable (which is simply random noise) to our dataset: +Oftentimes, we are interested in multiple estimands, or would like to investigate how our estimator is affected by changes in the nuisance estimands specification. In many cases, as long as the dataset under study is the same, it is possible to save some computational time by caching the previously learnt nuisance estimands. We describe below how TMLE.jl proposes to do that in some common scenarios. For that purpose let us add a new target variable (which is simply random noise) to our dataset: ```@example user-guide dataset.Ynew = rand(1000) @@ -253,7 +253,7 @@ nothing # hide ### Scenario 1: Changing the treatment values -Let us say we are interested in two ATE parameters: ``ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1}`` and ``ATE_{T_1=1 \rightarrow 0, T_2=0 \rightarrow 1}`` (Notice how the setting for $T_1$ has changed). +Let us say we are interested in two ATE estimands: ``ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1}`` and ``ATE_{T_1=1 \rightarrow 0, T_2=0 \rightarrow 1}`` (Notice how the setting for $T_1$ has changed). Let us start afresh an compute the first ATE: @@ -268,7 +268,7 @@ ate_result₁, cache = tmle(Ψ, η_spec, dataset) nothing # hide ``` -Notice the logs are informing you of all the nuisance parameters that are being fitted. +Notice the logs are informing you of all the nuisance estimands that are being fitted. Let us now investigate the second ATE by using the cache: @@ -283,11 +283,11 @@ ate_result₂, cache = tmle!(cache, Ψ) nothing # hide ``` -You should see that the logs are actually now telling you which nuisance parameters have been reused, i.e. all of them, only the targeting step needs to be done! This is because we already had nuisance estimators that matched our target parameter. +You should see that the logs are actually now telling you which nuisance estimands have been reused, i.e. all of them, only the targeting step needs to be done! This is because we already had nuisance estimators that matched our target estimand. ### Scenario 2: Changing the target -Let us now imagine that we are interested in another target: $Ynew$, we can say so by defining a new parameter and running the TMLE procedure using the cache: +Let us now imagine that we are interested in another target: $Ynew$, we can say so by defining a new estimand and running the TMLE procedure using the cache: ```@example user-guide Ψ = ATE( @@ -300,11 +300,11 @@ ate_result₃, cache = tmle!(cache, Ψ) nothing # hide ``` -As you can see, only $Q$ has been updated because the existing cached $G$ already matches our target parameter and cane be reused. +As you can see, only $Q$ has been updated because the existing cached $G$ already matches our target estimand and cane be reused. -### Scenario 3: Changing the nuisance parameters specification +### Scenario 3: Changing the nuisance estimands specification -Another common situation is to try a new model for a given nuisance parameter (or both). Here we can try a new regularization parameter for our logistic regression: +Another common situation is to try a new model for a given nuisance estimand (or both). Here we can try a new regularization estimand for our logistic regression: ```@example user-guide η_spec = NuisanceSpec( @@ -321,4 +321,4 @@ Since we have only updated $G$'s specification, only this model is fitted again. ### General behaviour Any change to either the `Estimand` of interest or the `NuisanceSpec` structures will trigger an update of the cache. If you have a predefined -list of parameters to estimate, you can optimize the estimation order of those parameters by calling `optimize_ordering!` or `optimize_ordering`. This will order the parameters in order to save the number of required nuisance parameters fits. +list of estimands to estimate, you can optimize the estimation order of those estimands by calling `optimize_ordering!` or `optimize_ordering`. This will order the estimands in order to save the number of required nuisance estimands fits. diff --git a/examples/introduction_to_targeted_learning.jl b/examples/introduction_to_targeted_learning.jl index e30acc22..a7f47a95 100644 --- a/examples/introduction_to_targeted_learning.jl +++ b/examples/introduction_to_targeted_learning.jl @@ -114,26 +114,26 @@ p^{do(height=h)}(score=s) = \int_g p(score=s | height=h, gender=g) For the sake of this tutorial we will assume that gender is the only confounding variable and can now move to the statistical estimation procedure. -## The conceptual shift: from statistical models to parameters +## The conceptual shift: from statistical models to estimands Traditional estimation methods start from a conveniently chosen parametric statistical model, proceed with maximum likelihood estimation -and finally **interpret** one of the model's parameter as the seeked effect size. For instance, a linear model is often assumed and in our scenario +and finally **interpret** one of the model's estimand as the seeked effect size. For instance, a linear model is often assumed and in our scenario could be formalized as : ```math Y = \alpha + \beta \cdot X + \epsilon , \epsilon \sim \mathcal{N}(0, \sigma^2) ``` -where ``\alpha``, ``\beta`` and ``\sigma^2`` are parameters to be estimated. In this case, ``\beta`` would be understood as +where ``\alpha``, ``\beta`` and ``\sigma^2`` are estimands to be estimated. In this case, ``\beta`` would be understood as the effect of X on Y. The problem with this approach is that if the model is wrong there is no guarantee that ``\beta`` will actually correspond to any effect size at all. We refer to such approaches as model-based -because the model is the starting point of the analysis. Targeted Learning on the other hand, can be considered parameter-based +because the model is the starting point of the analysis. Targeted Learning on the other hand, can be considered estimand-based because the starting point of the analysis is the question of interest. But how do you formulate a question without a model? The reality is that it requires some mathematical abstraction, which can be intimidating at first and keep you astray. Please do not, in fact you don't need to understand the mathematical details to use this package, and if you are trying to estimate some standard -effect size, there is a high chance the parameter you are looking for is the Average Treatment Effect (ATE). Conceptually, you can think +effect size, there is a high chance the estimand you are looking for is the Average Treatment Effect (ATE). Conceptually, you can think of the observed data as being generated by some complicated process which we will denote by ``P_0``. The ``0`` subscript reminds us -that this the ground truth, some unknown process that we can only witness through the observed data. A parameter, or question of interest is +that this the ground truth, some unknown process that we can only witness through the observed data. A estimand, or question of interest is simply a summary of this highly complicated ``P_0``. In our climbing example, we could ask the following question: "What average improvement in climbing grades would someone see after a year if they started climbing 3 times a week as compared to climbing only once a week". As you can see, the question is quite precise and the answer is expected to be a single number, for instance it could be 0, 1, 2, 3 grades... diff --git a/examples/super_learning.jl b/examples/super_learning.jl index 6756664e..96d7bb04 100644 --- a/examples/super_learning.jl +++ b/examples/super_learning.jl @@ -11,7 +11,7 @@ Super Learner will perform at least as well as its best performing submodel. Why is it important for Targeted Learning? The short answer is that the consistency (convergence in probability) of the targeted -estimator depends on the consistency of at least one of the nuisance parameters: ``Q_0`` or ``G_0`` (see [Mathematical setting](@ref) ). +estimator depends on the consistency of at least one of the nuisance estimands: ``Q_0`` or ``G_0`` (see [Mathematical setting](@ref) ). By only using unrealistic models like linear models, we have little chance of satisfying the above criterion. Super Learning is a data driven way to leverage a diverse set of models and build the best performing estimator for both ``Q_0`` or ``G_0``. @@ -176,7 +176,7 @@ nothing # hide #= ## Targeted estimation -Let us move to the targeted estimation step itself. We define the target parameter (the ATE) and the nuisance parameters specification: +Let us move to the targeted estimation step itself. We define the target estimand (the ATE) and the nuisance estimands specification: =# Ψ = ATE( target = :Y, diff --git a/src/TMLE.jl b/src/TMLE.jl index d0fd641f..f2b403b7 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -28,14 +28,14 @@ export NuisanceSpec, TMLECache, update!, CM, ATE, IATE export tmle, tmle! export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose -export parameters_from_yaml, parameters_to_yaml, optimize_ordering, optimize_ordering! +export estimands_from_yaml, estimands_to_yaml, optimize_ordering, optimize_ordering! # ############################################################################# # INCLUDES # ############################################################################# include("treatment_transformer.jl") -include("parameters.jl") +include("estimands.jl") include("utils.jl") include("cache.jl") include("estimate.jl") @@ -77,7 +77,7 @@ function run_precompile_workload() LinearBinaryClassifier() ) r, cache = tmle(Ψ, η_spec, dataset, verbosity=0) - continuous_parameters = [ + continuous_estimands = [ ATE( target =:Y₁, treatment=(T₁=(case=true, control=false),), @@ -92,12 +92,12 @@ function run_precompile_workload() # ) ] cache = TMLECache(dataset) - for Ψ in continuous_parameters + for Ψ in continuous_estimands tmle!(cache, Ψ, η_spec, verbosity=0) end # Precompiling with a binary target - binary_parameters = [ + binary_estimands = [ ATE( target =:Y₂, treatment=(T₁=(case=true, control=false),), @@ -116,7 +116,7 @@ function run_precompile_workload() LinearBinaryClassifier() ) - for Ψ in binary_parameters + for Ψ in binary_estimands tmle!(cache, Ψ, η_spec, verbosity=0) end diff --git a/src/cache.jl b/src/cache.jl index abf69750..78e4b235 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -1,11 +1,11 @@ """ This structure is used as a cache for: - data: data changes - - η: The nuisance parameters estimators - - Ψ: The parameter of interest - - η_spec: The specification of the learning algorithms used to estimate the nuisance parameters + - η: The nuisance estimands estimators + - Ψ: The estimand of interest + - η_spec: The specification of the learning algorithms used to estimate the nuisance estimands -In many cases, if some of those fields do not change, the nuisance parameters do not have to be estimated again hence +In many cases, if some of those fields do not change, the nuisance estimands do not have to be estimated again hence saving computations. """ mutable struct TMLECache @@ -110,7 +110,7 @@ end const TMLE_ARGS_DOCS = """ # Arguments -- Ψ: The parameter of interest +- Ψ: The estimand of interest - η_spec: The specification for learning `Q_0` and `G_0` - dataset: A tabular dataset respecting the Table.jl interface - verbosity: The logging level @@ -139,12 +139,12 @@ Runs TMLE using the provided TMLECache. $TMLE_ARGS_DOCS """ function tmle!(cache::TMLECache; verbosity=1, threshold=1e-8, weighted_fluctuation=false) - # Initial fit of the nuisance parameters - verbosity >= 1 && @info "Fitting the nuisance parameters..." + # Initial fit of the nuisance estimands + verbosity >= 1 && @info "Fitting the nuisance estimands..." TMLE.fit_nuisance!(cache, verbosity=verbosity) # TMLE step - verbosity >= 1 && @info "Targeting the nuisance parameters..." + verbosity >= 1 && @info "Targeting the nuisance estimands..." tmle_step!(cache, verbosity=verbosity, threshold=threshold, weighted_fluctuation=weighted_fluctuation) # Estimation results after TMLE @@ -208,7 +208,7 @@ end """ fit_nuisance!(cache::TMLECache; verbosity=1, mach_cache=false) -Fits the nuisance parameters η on the dataset using the specifications from η_spec +Fits the nuisance estimands η on the dataset using the specifications from η_spec and the variables defined by Ψ. """ function fit_nuisance!(cache::TMLECache; verbosity=1) diff --git a/src/configuration.jl b/src/configuration.jl index 2f3b95b6..38ec485a 100644 --- a/src/configuration.jl +++ b/src/configuration.jl @@ -8,29 +8,29 @@ function param_to_dict(Ψ::Estimand) end """ - parameters_from_yaml(filepath) + estimands_from_yaml(filepath) Read estimation `Estimands` from YAML file. """ -function parameters_from_yaml(filepath) +function estimands_from_yaml(filepath) yaml_dict = YAML.load_file(filepath; dicttype=Dict{String, Any}) params_dicts = yaml_dict["Estimands"] nparams = size(params_dicts, 1) - parameters = Vector{Estimand}(undef, nparams) + estimands = Vector{Estimand}(undef, nparams) for index in 1:nparams d = params_dicts[index] type = pop!(d, "type") - parameters[index] = from_dict(eval(Meta.parse(type)), d) + estimands[index] = from_dict(eval(Meta.parse(type)), d) end - return parameters + return estimands end """ - parameters_to_yaml(filepath, parameters::Vector{<:Estimand}) + estimands_to_yaml(filepath, estimands::Vector{<:Estimand}) Write estimation `Estimands` to YAML file. """ -function parameters_to_yaml(filepath, parameters) - d = Dict("Estimands" => [param_to_dict(Ψ) for Ψ in parameters]) +function estimands_to_yaml(filepath, estimands) + d = Dict("Estimands" => [param_to_dict(Ψ) for Ψ in estimands]) YAML.write_file(filepath, d) end \ No newline at end of file diff --git a/src/parameters.jl b/src/estimands.jl similarity index 93% rename from src/parameters.jl rename to src/estimands.jl index 9b878760..7f62bb0d 100644 --- a/src/parameters.jl +++ b/src/estimands.jl @@ -187,7 +187,7 @@ end """ NuisanceSpec(Q, G; H=encoder(), F=Q_model(target_scitype(Q))) -Specification of the nuisance parameters to be learnt. +Specification of the nuisance estimands to be learnt. # Arguments: @@ -257,23 +257,23 @@ function param_key(Ψ::Estimand) end """ - optimize_ordering!(parameters::Vector{<:Estimand}) + optimize_ordering!(estimands::Vector{<:Estimand}) -Reorders the given parameters so that most nuisance parameters fits can be +Reorders the given estimands so that most nuisance estimands fits can be reused. Given the assumed causal graph: $causal_graph and the requirements to estimate both p(T|W) and E[Y|W, T, C]. -A natural ordering of the parameters in order to save computations is given by the +A natural ordering of the estimands in order to save computations is given by the following variables ordering: (W, T, Y, C) """ -optimize_ordering!(parameters::Vector{<:Estimand}) = sort!(parameters, by=param_key) +optimize_ordering!(estimands::Vector{<:Estimand}) = sort!(estimands, by=param_key) """ - optimize_ordering(parameters::Vector{<:Estimand}) + optimize_ordering(estimands::Vector{<:Estimand}) See [`optimize_ordering!`](@ref) """ -optimize_ordering(parameters::Vector{<:Estimand}) = sort(parameters, by=param_key) \ No newline at end of file +optimize_ordering(estimands::Vector{<:Estimand}) = sort(estimands, by=param_key) \ No newline at end of file diff --git a/src/estimate.jl b/src/estimate.jl index b3516726..910ddfb0 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -12,7 +12,7 @@ struct ComposedTMLE{T<:AbstractFloat} <: AbstractTMLE end struct TMLEResult{P <: Estimand, T<:AbstractFloat} - parameter::P + estimand::P tmle::ALEstimate{T} onestep::ALEstimate{T} initial::T diff --git a/test/cache.jl b/test/cache.jl index 44c9f1e7..8d486f20 100644 --- a/test/cache.jl +++ b/test/cache.jl @@ -100,14 +100,14 @@ end for colname in keys(cache.data[:no_missing]) cache.data[:no_missing][colname] === cache.data[:source][colname] end - # Pretend the cache nuisance parameters have been estimated + # Pretend the cache nuisance estimands have been estimated fakefill!(cache) @test cache.η.H !== nothing @test cache.η.G !== nothing @test cache.η.Q !== nothing @test cache.η.F !== nothing # New treatment configuration and new confounders set - # Nuisance parameters can be reused + # Nuisance estimands can be reused Ψ = ATE( target=:y, treatment=(T₁=(case=0, control=1),), @@ -241,7 +241,7 @@ end MLJModels.ConstantRegressor(), ConstantClassifier() ) - # Nuisance parameter estimation + # Nuisance estimand estimation tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); # The inital estimator for Q is a constant predictor, # its prediction is the same for each counterfactual diff --git a/test/configuration.jl b/test/configuration.jl index a4663c0f..b14a8a70 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -5,8 +5,8 @@ using TMLE using Serialization @testset "Test configurations" begin - param_file = "parameters_sample.yaml" - parameters = [ + param_file = "estimands_sample.yaml" + estimands = [ IATE(; target=:Y1, treatment=(T2 = (case = 1, control = 0), T1 = (case = 1, control = 0)), @@ -22,14 +22,14 @@ using Serialization CM(;target=:Y3, treatment=(T3 = "AC", T1 = "CC"), confounders=[:W1], covariates=Symbol[]) ] # Test YAML Serialization - parameters_to_yaml(param_file, parameters) - new_parameters = parameters_from_yaml(param_file) - @test new_parameters == parameters + estimands_to_yaml(param_file, estimands) + new_estimands = estimands_from_yaml(param_file) + @test new_estimands == estimands rm(param_file) # Test basic Serialization - serialize(param_file, parameters) - new_parameters = deserialize(param_file) - @test new_parameters == parameters + serialize(param_file, estimands) + new_estimands = deserialize(param_file) + @test new_estimands == estimands rm(param_file) end diff --git a/test/parameters.jl b/test/estimands.jl similarity index 94% rename from test/parameters.jl rename to test/estimands.jl index 4df32712..6cc1dd92 100644 --- a/test/parameters.jl +++ b/test/estimands.jl @@ -37,7 +37,7 @@ using TMLE end @testset "Test optimize_ordering" begin - parameters = [ + estimands = [ ATE( target=:Y, treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), @@ -83,7 +83,7 @@ end ), ] # Non mutating function - ordered_parameters = optimize_ordering(parameters) + ordered_estimands = optimize_ordering(estimands) expected_ordering = [ ATE(:Y, (T₁ = (case = 1, control = 0),), [:W₁], Symbol[]), CM(:Y, (T₁ = 0,), [:W₁], Symbol[]), @@ -94,10 +94,10 @@ end ATE(:Y₂, (T₁ = (case = 1, control = 0), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], Symbol[]), ATE(:Y₂, (T₁ = (case = 0, control = 1), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], [:C₂]) ] - @test ordered_parameters == expected_ordering + @test ordered_estimands == expected_ordering # Mutating function - optimize_ordering!(parameters) - @test parameters == ordered_parameters + optimize_ordering!(estimands) + @test estimands == ordered_estimands end @testset "Test structs are concrete types" begin diff --git a/test/runtests.jl b/test/runtests.jl index acff0680..bf159d6c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,7 @@ using Test @test include("double_robustness_iate.jl") @test include("3points_interactions.jl") @test include("warm_restart.jl") - @test include("parameters.jl") + @test include("estimands.jl") @test include("missing_management.jl") @test include("composition.jl") @test include("treatment_transformer.jl") diff --git a/test/utils.jl b/test/utils.jl index 64dac943..54f51f16 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -12,8 +12,8 @@ using MLJGLMInterface: LinearBinaryClassifier using MLJLinearModels using MLJModels -function test_params_match(parameters, expected_params) - for (param, expected_param) in zip(parameters, expected_params) +function test_params_match(estimands, expected_params) + for (param, expected_param) in zip(estimands, expected_params) @test typeof(param) == typeof(expected_param) @test param.target == expected_param.target @test param.treatment == expected_param.treatment diff --git a/test/warm_restart.jl b/test/warm_restart.jl index de0c078d..4436a5f9 100644 --- a/test/warm_restart.jl +++ b/test/warm_restart.jl @@ -50,37 +50,37 @@ table_types = (Tables.columntable, DataFrame) @testset "Test Warm restart: ATE single treatment, $tt" for tt in table_types dataset = tt(build_dataset(;n=50_000)) - # Define the parameter of interest + # Define the estimand of interest Ψ = ATE( target=:y₁, treatment=(T₁=(case=true, control=false),), confounders=[:W₁, :W₂], covariates=[:C₁] ) - # Define the nuisance parameters specification + # Define the nuisance estimands specification η_spec = NuisanceSpec( LinearRegressor(), LogisticClassifier(lambda=0) ) # Run TMLE log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Fitting P(T|W)"), (:info, "→ Fitting Encoder"), (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle(Ψ, η_spec, dataset; verbosity=1); # The TMLE covers the ground truth but the initial estimate does not Ψ₀ = -1 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Update the treatment specification - # Nuisance parameters should not fitted again + # Nuisance estimands should not fitted again Ψ = ATE( target=:y₁, treatment=(T₁=(case=false, control=true),), @@ -88,16 +88,16 @@ table_types = (Tables.columntable, DataFrame) covariates=[:C₁] ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Reusing previous P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); Ψ₀ = 1 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -109,16 +109,16 @@ table_types = (Tables.columntable, DataFrame) confounders=[:W₁, :W₂], ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Reusing previous P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1, weighted_fluctuation=true); Ψ₀ = -1 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -131,16 +131,16 @@ table_types = (Tables.columntable, DataFrame) confounders=[:W₁, :W₂], ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Fitting P(T|W)"), (:info, "→ Fitting Encoder"), (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); Ψ₀ = 0.5 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -155,15 +155,15 @@ table_types = (Tables.columntable, DataFrame) confounders=[:W₁], ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Fitting P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1, weighted_fluctuation=true); - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Change the target @@ -174,16 +174,16 @@ table_types = (Tables.columntable, DataFrame) confounders=[:W₁], ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Reusing previous P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); Ψ₀ = 10 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₂) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -197,7 +197,7 @@ table_types = (Tables.columntable, DataFrame) covariates=[:C₁] ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₂) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -205,21 +205,21 @@ end @testset "Test Warm restart: ATE multiple treatment, $tt" for tt in table_types dataset = tt(build_dataset(;n=50_000)) - # Define the parameter of interest + # Define the estimand of interest Ψ = ATE( target=:y₁, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders=[:W₁, :W₂], covariates=[:C₁] ) - # Define the nuisance parameters specification + # Define the nuisance estimands specification η_spec = NuisanceSpec( LinearRegressor(), LogisticClassifier(lambda=0) ) tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); Ψ₀ = -0.5 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -232,16 +232,16 @@ end covariates=[:C₁] ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Reusing previous P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); Ψ₀ = -1.5 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -261,7 +261,7 @@ end ) tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); Ψ₀ = 3 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -274,16 +274,16 @@ end covariates=[:C₁] ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Reusing previous P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); Ψ₀ = 2.5 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -296,16 +296,16 @@ end covariates=[:C₁] ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Reusing previous P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); Ψ₀ = 1 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₂) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -318,16 +318,16 @@ end covariates=[:C₁] ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Fitting P(T|W)"), (:info, "→ Fitting Encoder"), (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); Ψ₀ = 11 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₂) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -347,7 +347,7 @@ end ) tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); Ψ₀ = -1 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₃) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -359,15 +359,15 @@ end confounders=[:W₁, :W₂], ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Reusing previous P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1, weighted_fluctuation=true); - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -378,16 +378,16 @@ end confounders=[:W₁, :W₂], ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Reusing previous P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); Ψ₀ = - Ψ₀ - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₃) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -396,21 +396,21 @@ end @testset "Test Warm restart: Both Ψ and η changed" begin dataset = build_dataset(;n=50_000) - # Define the parameter of interest + # Define the estimand of interest Ψ = ATE( target=:y₁, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders=[:W₁, :W₂], covariates=[:C₁] ) - # Define the nuisance parameters specification + # Define the nuisance estimands specification η_spec = NuisanceSpec( LinearRegressor(), LogisticClassifier(lambda=0) ) tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); Ψ₀ = -0.5 - @test tmle_result.parameter == Ψ + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -421,22 +421,22 @@ end confounders=[:W₁, :W₂], covariates=[:C₁] ) - # Define the nuisance parameters specification + # Define the nuisance estimands specification η_spec_new = NuisanceSpec( LinearRegressor(), LogisticClassifier(lambda=0.1) ) log_sequence = ( - (:info, "Fitting the nuisance parameters..."), + (:info, "Fitting the nuisance estimands..."), (:info, "→ Reusing previous P(T|W)"), (:info, "→ Reusing previous Encoder"), (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance parameters..."), + (:info, "Targeting the nuisance estimands..."), (:info, "Done.") ) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψnew, η_spec; verbosity=1); Ψ₀ = 0.5 - @test tmle_result.parameter == Ψnew + @test tmle_result.estimand == Ψnew test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(cache; target_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) From 2d57de23f94d373609dfa5cef39d2bf273bac1dd Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 2 Aug 2023 14:02:31 +0100 Subject: [PATCH 003/151] rename target to outcome --- docs/src/index.md | 4 +- docs/src/user_guide.md | 30 ++++---- examples/double_robustness.jl | 2 +- examples/introduction_to_targeted_learning.jl | 8 +- examples/super_learning.jl | 2 +- src/TMLE.jl | 12 +-- src/cache.jl | 8 +- src/estimands.jl | 44 +++++------ test/3points_interactions.jl | 4 +- test/cache.jl | 24 +++--- test/composition.jl | 4 +- test/configuration.jl | 6 +- test/double_robustness_ate.jl | 48 ++++++------ test/double_robustness_iate.jl | 36 ++++----- test/estimands.jl | 24 +++--- test/fit_nuisance.jl | 32 ++++---- test/helper_fns.jl | 10 +-- test/missing_management.jl | 4 +- test/non_regression_test.jl | 2 +- test/utils.jl | 16 ++-- test/warm_restart.jl | 74 +++++++++---------- 21 files changed, 197 insertions(+), 197 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 4ee9c188..1074b245 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -47,7 +47,7 @@ Specifying which machine-learning method to use for each of those nuisance estim ### The Conditional Mean -This is simply the expected value of the target $Y$ when the treatment is set to $t$. +This is simply the expected value of the outcome $Y$ when the treatment is set to $t$. ```math CM_t(P) = \mathbb{E}[\mathbb{E}[Y|T=t, W]] @@ -124,7 +124,7 @@ And say we are interested in the $ATE_{0 \rightarrow 1}(P_0)$: ```@example quick-start Ψ = ATE( - target = :Y, + outcome = :Y, treatment = (T=(case=true, control = false),), confounders = [:W] ) diff --git a/docs/src/user_guide.md b/docs/src/user_guide.md index 44ba6ed7..7e73984d 100644 --- a/docs/src/user_guide.md +++ b/docs/src/user_guide.md @@ -71,7 +71,7 @@ The `NuisanceSpec` struct also holds a specification for the `OneHotEncoder` nec ### The Conditional mean -We are now ready to move to the definition of the estimands of interest. The most basic type of estimand is the conditional mean of the target given the treatment: +We are now ready to move to the definition of the estimands of interest. The most basic type of estimand is the conditional mean of the outcome given the treatment: ```math CM_t(P) = \mathbb{E}[\mathbb{E}[Y|T=t, W]] @@ -81,7 +81,7 @@ The treatment does not have to be restricted to a single variable, we can define ```@example user-guide Ψ = CM( - target = :Y, + outcome = :Y, treatment = (T₁=true, T₂=true), confounders = [:W₁, :W₂] ) @@ -142,7 +142,7 @@ Let's see what the TMLE tells us: ```@example user-guide Ψ = ATE( - target = :Y, + outcome = :Y, treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders = [:W₁, :W₂] ) @@ -171,7 +171,7 @@ and run: ```@example user-guide Ψ = IATE( - target = :Y, + outcome = :Y, treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders = [:W₁, :W₂] ) @@ -189,7 +189,7 @@ For instance, by definition of the ATE, we should be able to retrieve ``ATE_{T_1 ```@example user-guide Ψ = CM( - target = :Y, + outcome = :Y, treatment = (T₁=false, T₂=false), confounders = [:W₁, :W₂] ) @@ -220,7 +220,7 @@ It may be useful to read and write estimands to text files. We provide this func ```yaml Estimands: - - target: Y1 + - outcome: Y1 treatment: (T2 = (case = 1, control = 0), T1 = (case = 1, control = 0)) confounders: - W1 @@ -228,13 +228,13 @@ Estimands: covariates: - C1 type: IATE - - target: Y3 + - outcome: Y3 treatment: (T3 = (case = 1, control = 0), T1 = (case = "AC", control = "CC")) confounders: - W1 covariates: [] type: ATE - - target: Y3 + - outcome: Y3 treatment: (T3 = "AC", T1 = "CC") confounders: - W1 @@ -244,7 +244,7 @@ Estimands: ## Using the cache -Oftentimes, we are interested in multiple estimands, or would like to investigate how our estimator is affected by changes in the nuisance estimands specification. In many cases, as long as the dataset under study is the same, it is possible to save some computational time by caching the previously learnt nuisance estimands. We describe below how TMLE.jl proposes to do that in some common scenarios. For that purpose let us add a new target variable (which is simply random noise) to our dataset: +Oftentimes, we are interested in multiple estimands, or would like to investigate how our estimator is affected by changes in the nuisance estimands specification. In many cases, as long as the dataset under study is the same, it is possible to save some computational time by caching the previously learnt nuisance estimands. We describe below how TMLE.jl proposes to do that in some common scenarios. For that purpose let us add a new outcome variable (which is simply random noise) to our dataset: ```@example user-guide dataset.Ynew = rand(1000) @@ -259,7 +259,7 @@ Let us start afresh an compute the first ATE: ```@example user-guide Ψ = ATE( - target = :Y, + outcome = :Y, treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders = [:W₁, :W₂] ) @@ -274,7 +274,7 @@ Let us now investigate the second ATE by using the cache: ```@example user-guide Ψ = ATE( - target = :Y, + outcome = :Y, treatment = (T₁=(case=false, control=true), T₂=(case=true, control=false)), confounders = [:W₁, :W₂] ) @@ -283,15 +283,15 @@ ate_result₂, cache = tmle!(cache, Ψ) nothing # hide ``` -You should see that the logs are actually now telling you which nuisance estimands have been reused, i.e. all of them, only the targeting step needs to be done! This is because we already had nuisance estimators that matched our target estimand. +You should see that the logs are actually now telling you which nuisance estimands have been reused, i.e. all of them, only the targeting step needs to be done! This is because we already had nuisance estimators that matched our outcome estimand. -### Scenario 2: Changing the target +### Scenario 2: Changing the outcome -Let us now imagine that we are interested in another target: $Ynew$, we can say so by defining a new estimand and running the TMLE procedure using the cache: +Let us now imagine that we are interested in another outcome: $Ynew$, we can say so by defining a new estimand and running the TMLE procedure using the cache: ```@example user-guide Ψ = ATE( - target = :Ynew, + outcome = :Ynew, treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders = [:W₁, :W₂] ) diff --git a/examples/double_robustness.jl b/examples/double_robustness.jl index 7992c00e..d7ef7541 100644 --- a/examples/double_robustness.jl +++ b/examples/double_robustness.jl @@ -154,7 +154,7 @@ that we now have full coverage of the ground truth. function tmle_inference(data) Ψ = ATE( - target=:Y, + outcome=:Y, treatment=(Tcat=(case=1.0, control=0.0),), confounders=[:W] ) diff --git a/examples/introduction_to_targeted_learning.jl b/examples/introduction_to_targeted_learning.jl index a7f47a95..38f8d086 100644 --- a/examples/introduction_to_targeted_learning.jl +++ b/examples/introduction_to_targeted_learning.jl @@ -202,20 +202,20 @@ function counterfactualX(dataset, features, height) end end -results = DataFrame(height_high=[], height_low=[], target=[], features=[], estimate=[]) -for target in (:mean_score, :max_score) +results = DataFrame(height_high=[], height_low=[], outcome=[], features=[], estimate=[]) +for outcome in (:mean_score, :max_score) for features in [[:height, :sex], [:height]] mach = machine( Q̂_spec, dataset[!, features], - dataset[!, target] + dataset[!, outcome] ) fit!(mach, verbosity=0) for (low, high) in zip(height_buckets[1:end-1], height_buckets[2:end]) X_high = counterfactualX(dataset, features, high) X_low = counterfactualX(dataset, features, low) estimate = naive_estimate(mach, X_high, X_low) - push!(results, (high, low, target, join(features, :_), estimate)) + push!(results, (high, low, outcome, join(features, :_), estimate)) end end end diff --git a/examples/super_learning.jl b/examples/super_learning.jl index 96d7bb04..d6e7a1b3 100644 --- a/examples/super_learning.jl +++ b/examples/super_learning.jl @@ -179,7 +179,7 @@ nothing # hide Let us move to the targeted estimation step itself. We define the target estimand (the ATE) and the nuisance estimands specification: =# Ψ = ATE( - target = :Y, + outcome = :Y, treatment = (T=(case=true, control=false),), confounders = [:W₁, :W₂] ) diff --git a/src/TMLE.jl b/src/TMLE.jl index f2b403b7..29a66ddc 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -67,7 +67,7 @@ function run_precompile_workload() # all calls in this block will be precompiled, regardless of whether # they belong to your package or not (on Julia 1.8 and higher) Ψ = CM( - target =:Y₁, + outcome =:Y₁, treatment=(T₁=true,), confounders=[:W₁, :W₂], covariates = [:C₁] @@ -79,13 +79,13 @@ function run_precompile_workload() r, cache = tmle(Ψ, η_spec, dataset, verbosity=0) continuous_estimands = [ ATE( - target =:Y₁, + outcome =:Y₁, treatment=(T₁=(case=true, control=false),), confounders=[:W₁, :W₂], covariates = [:C₁] ), # IATE( - # target =:Y₁, + # outcome =:Y₁, # treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), # confounders=[:W₁, :W₂], # covariates = [:C₁] @@ -96,16 +96,16 @@ function run_precompile_workload() tmle!(cache, Ψ, η_spec, verbosity=0) end - # Precompiling with a binary target + # Precompiling with a binary outcome binary_estimands = [ ATE( - target =:Y₂, + outcome =:Y₂, treatment=(T₁=(case=true, control=false),), confounders=[:W₁, :W₂], covariates = [:C₁] ), # IATE( - # target =:Y₂, + # outcome =:Y₂, # treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), # confounders=[:W₁, :W₂], # covariates = [:C₁] diff --git a/src/cache.jl b/src/cache.jl index 78e4b235..fdc7b9d5 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -69,7 +69,7 @@ function update!(cache::TMLECache, Ψ::Estimand) cache.η.Q = nothing any_variable_changed = true end - if cache.Ψ.target != Ψ.target + if cache.Ψ.outcome != Ψ.outcome cache.η.Q = nothing any_variable_changed = true end @@ -242,7 +242,7 @@ function fit_nuisance!(cache::TMLECache; verbosity=1) end # Data X = Qinputs(η.H, cache.data[:no_missing], Ψ) - y = target(cache.data[:no_missing], Ψ) + y = outcome(cache.data[:no_missing], Ψ) # Fitting E[Y|X] log_fit(verbosity, "E[Y|X]") mach = machine(η_spec.Q, X, y, cache=η_spec.cache) @@ -268,7 +268,7 @@ function tmle_step!(cache::TMLECache; verbosity=1, threshold=1e-8, weighted_fluc threshold=threshold, weighted_fluctuation=weighted_fluctuation ) X = TMLE.fluctuation_input(covariate, offset) - y = TMLE.target(cache.data[:no_missing], cache.Ψ) + y = TMLE.outcome(cache.data[:no_missing], cache.Ψ) mach = machine(cache.η_spec.F, X, y, weights, cache=cache.η_spec.cache) MLJBase.fit!(mach, verbosity=verbosity-1) # Update cache @@ -325,7 +325,7 @@ This part of the gradient is evaluated on the original dataset. All quantities h """ function gradients_Y_X(cache::TMLECache) covariate = cache.data[:covariate] - y = float(target(cache.data[:no_missing], cache.Ψ)) + y = float(outcome(cache.data[:no_missing], cache.Ψ)) gradient_Y_Xᵢ = covariate .* (y .- expected_value(cache.data[:Q₀])) gradient_Y_X_fluct = covariate .* (y .- expected_value(cache.data[:Qfluct])) return gradient_Y_Xᵢ, gradient_Y_X_fluct diff --git a/src/estimands.jl b/src/estimands.jl index 7f62bb0d..54978aaa 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -5,7 +5,7 @@ const causal_graph = """ ## Notation: -- Y: target +- Y: outcome - T: treatment - W: confounders - C: covariates @@ -36,29 +36,29 @@ Mathematical definition: $causal_graph # Fields: - - target: A symbol identifying the target variable of interest + - outcome: A symbol identifying the outcome variable of interest - treatment: A NamedTuple linking each treatment variable to a value - - confounders: Confounding variables affecting both the target and the treatment - - covariates: Optional extra variables affecting the target only + - confounders: Confounding variables affecting both the outcome and the treatment + - covariates: Optional extra variables affecting the outcome only # Examples: ```julia CM₁ = CM( - target=:Y₁, + outcome=:Y₁, treatment=(T₁=1,), confounders=[:W₁, :W₂], covariates=[:C₁] ) CM₂ = CM( - target=:Y₂, + outcome=:Y₂, treatment=(T₁=1, T₂="A"), confounders=[:W₁], ) ``` """ @option struct CM <: Estimand - target::Symbol + outcome::Symbol treatment::NamedTuple confounders::Vector{Symbol} covariates::Vector{Symbol} = Symbol[] @@ -80,29 +80,29 @@ Mathematical definition: $causal_graph # Fields: - - target: A symbol identifying the target variable of interest + - outcome: A symbol identifying the outcome variable of interest - treatment: A NamedTuple linking each treatment variable to case/control values - - confounders: Confounding variables affecting both the target and the treatment - - covariates: Optional extra variables affecting the target only + - confounders: Confounding variables affecting both the outcome and the treatment + - covariates: Optional extra variables affecting the outcome only # Examples: ```julia ATE₁ = ATE( - target=:Y₁, + outcome=:Y₁, treatment=(T₁=(case=1, control=0),), confounders=[:W₁, :W₂], covariates=[:C₁] ) ATE₂ = ATE( - target=:Y₂, + outcome=:Y₂, treatment=(T₁=(case=1, control=0), T₂=(case="A", control="B")), confounders=[:W₁], ) ``` """ @option struct ATE <: Estimand - target::Symbol + outcome::Symbol treatment::NamedTuple confounders::Vector{Symbol} covariates::Vector{Symbol} = Symbol[] @@ -125,22 +125,22 @@ Mathematical definition for pairwise interaction: $causal_graph # Fields: - - target: A symbol identifying the target variable of interest + - outcome: A symbol identifying the outcome variable of interest - treatment: A NamedTuple linking each treatment variable to case/control values - - confounders: Confounding variables affecting both the target and the treatment - - covariates: Optional extra variables affecting the target only + - confounders: Confounding variables affecting both the outcome and the treatment + - covariates: Optional extra variables affecting the outcome only # Examples: ```julia IATE₁ = IATE( - target=:Y₁, + outcome=:Y₁, treatment=(T₁=(case=1, control=0), T₂=(case="A", control="B")), confounders=[:W₁], ) ``` """ @option struct IATE <: Estimand - target::Symbol + outcome::Symbol treatment::NamedTuple confounders::Vector{Symbol} covariates::Vector{Symbol} = Symbol[] @@ -215,8 +215,8 @@ covariates(dataset, Ψ) = selectcols(dataset, covariates(Ψ)) treatments(Ψ::Estimand) = collect(keys(Ψ.treatment)) treatments(dataset, Ψ) = selectcols(dataset, treatments(Ψ)) -target(Ψ::Estimand) = Ψ.target -target(dataset, Ψ) = Tables.getcolumn(dataset, target(Ψ)) +outcome(Ψ::Estimand) = Ψ.outcome +outcome(dataset, Ψ) = Tables.getcolumn(dataset, outcome(Ψ)) treatment_and_confounders(Ψ::Estimand) = vcat(confounders(Ψ), treatments(Ψ)) @@ -232,7 +232,7 @@ Qinputs(H, dataset, Ψ::Estimand) = merge( ) -allcolumns(Ψ::Estimand) = vcat(confounders_and_covariates(Ψ), treatments(Ψ), target(Ψ)) +allcolumns(Ψ::Estimand) = vcat(confounders_and_covariates(Ψ), treatments(Ψ), outcome(Ψ)) F_model(::Type{<:AbstractVector{<:MLJBase.Continuous}}) = LinearRegressor(fit_intercept=false, offsetcol = :offset) @@ -251,7 +251,7 @@ function param_key(Ψ::Estimand) return ( join(Ψ.confounders, "_"), join(keys(Ψ.treatment), "_"), - string(Ψ.target), + string(Ψ.outcome), join(Ψ.covariates, "_") ) end diff --git a/test/3points_interactions.jl b/test/3points_interactions.jl index b895779e..12c052a7 100644 --- a/test/3points_interactions.jl +++ b/test/3points_interactions.jl @@ -28,7 +28,7 @@ end @testset "Test 3-points interactions" begin dataset, Ψ₀ = make_dataset(;n=1000) Ψ = IATE( - target = :y, + outcome = :y, confounders = [:W], treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false), T₃=(case=1, control=0)) ) @@ -39,7 +39,7 @@ end tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache, target_name=:y) + test_fluct_decreases_risk(cache, outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) tmle_result, _ = tmle!(cache, verbosity=0, weighted_fluctuation=true) diff --git a/test/cache.jl b/test/cache.jl index 8d486f20..6c7a8c6c 100644 --- a/test/cache.jl +++ b/test/cache.jl @@ -49,13 +49,13 @@ end @testset "Test check_treatment_values" begin cache = fakecache() Ψ = ATE( - target=:y, + outcome=:y, treatment=(T₁=(case=1, control=0),), confounders=[:W₁], ) TMLE.check_treatment_values(cache, Ψ) Ψ = ATE( - target=:y, + outcome=:y, treatment=( T₁=(case=1, control=0), T₂=(case=true, control=false),), @@ -66,7 +66,7 @@ end " in the dataset: [\"0\", \"1\"]")) TMLE.check_treatment_values(cache, Ψ) Ψ = CM( - target=:y, + outcome=:y, treatment=(T₃="CT",), confounders=[:W₁], ) @@ -79,7 +79,7 @@ end @test !isdefined(cache, :η_spec) Ψ = ATE( - target=:y, + outcome=:y, treatment=(T₁=(case=1, control=0),), confounders=[:W₁], covariates=[:C₁] @@ -109,7 +109,7 @@ end # New treatment configuration and new confounders set # Nuisance estimands can be reused Ψ = ATE( - target=:y, + outcome=:y, treatment=(T₁=(case=0, control=1),), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -124,11 +124,11 @@ end @test length(cache.data[:no_missing][colname]) == 8 end - # Change the target + # Change the outcome # E[Y|X] must be refit fakefill!(cache) Ψ = ATE( - target=:ynew, + outcome=:ynew, treatment=(T₁=(case=0, control=1),), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -147,7 +147,7 @@ end # E[Y|X] must be refit fakefill!(cache) Ψ = ATE( - target=:ynew, + outcome=:ynew, treatment=(T₁=(case=0, control=1),), confounders=[:W₁, :W₂], ) @@ -162,7 +162,7 @@ end # only F needs refit fakefill!(cache) Ψ = ATE( - target=:ynew, + outcome=:ynew, treatment=(T₁=(case=1, control=0),), confounders=[:W₁, :W₂], ) @@ -177,7 +177,7 @@ end # all η must be refit fakefill!(cache) Ψ = ATE( - target=:ynew, + outcome=:ynew, treatment=(T₂=(case=1, control=0),), confounders=[:W₁, :W₂], ) @@ -192,7 +192,7 @@ end @testset "Test update!(cache, η_spec)" begin cache = fakecache() Ψ = ATE( - target=:y, + outcome=:y, treatment=(T₁=(case=1, control=0),), confounders=[:W₁], covariates=[:C₁] @@ -233,7 +233,7 @@ end n=100 dataset = naive_dataset(;n=n) Ψ = ATE( - target = :y, + outcome = :y, treatment = (T=(case=1, control=0),), confounders = [:W] ) diff --git a/test/composition.jl b/test/composition.jl index 81e0a6ad..3e4f175c 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -21,7 +21,7 @@ function make_dataset(;n=100) end basicCM(val) = CM( - target = :y, + outcome = :y, treatment = (T=val,), confounders = [:W] ) @@ -51,7 +51,7 @@ end # Via ATE ATE₁₀ = ATE( - target = :y, + outcome = :y, treatment = (T=(case=1, control=0),), confounders = [:W] ) diff --git a/test/configuration.jl b/test/configuration.jl index b14a8a70..950e16c3 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -8,18 +8,18 @@ using Serialization param_file = "estimands_sample.yaml" estimands = [ IATE(; - target=:Y1, + outcome=:Y1, treatment=(T2 = (case = 1, control = 0), T1 = (case = 1, control = 0)), confounders=[:W1, :W2], covariates=Symbol[:C1] ), ATE(; - target=:Y3, + outcome=:Y3, treatment=(T3 = (case = 1, control = 0), T1 = (case = "AC", control = "CC")), confounders=[:W1], covariates=Symbol[] ), - CM(;target=:Y3, treatment=(T3 = "AC", T1 = "CC"), confounders=[:W1], covariates=Symbol[]) + CM(;outcome=:Y3, treatment=(T3 = "AC", T1 = "CC"), confounders=[:W1], covariates=Symbol[]) ] # Test YAML Serialization estimands_to_yaml(param_file, estimands) diff --git a/test/double_robustness_ate.jl b/test/double_robustness_ate.jl index d0940d07..46311fdc 100644 --- a/test/double_robustness_ate.jl +++ b/test/double_robustness_ate.jl @@ -17,7 +17,7 @@ include("helper_fns.jl") """ Q and G are two logistic models """ -function binary_target_binary_treatment_pb(;n=100) +function binary_outcome_binary_treatment_pb(;n=100) rng = StableRNG(123) p_w() = 0.3 pa_given_w(w) = 1 ./ (1 .+ exp.(-0.5w .+ 1)) @@ -44,7 +44,7 @@ end From https://www.degruyter.com/document/doi/10.2202/1557-4679.1043/html The theoretical ATE is 1 """ -function continuous_target_binary_treatment_pb(;n=100) +function continuous_outcome_binary_treatment_pb(;n=100) rng = StableRNG(123) Unif = Uniform(0, 1) W = float(rand(rng, Bernoulli(0.5), n, 3)) @@ -56,7 +56,7 @@ function continuous_target_binary_treatment_pb(;n=100) return (T = T, W₁ = W₁, W₂ = W₂, W₃ = W₃, y = y), 4 end -function continuous_target_categorical_treatment_pb(;n=100, control="TT", case="AA") +function continuous_outcome_categorical_treatment_pb(;n=100, control="TT", case="AA") rng = StableRNG(123) ft(T) = (T .== "AA") - (T .== "AT") + 2(T .== "TT") fw(W₁, W₂, W₃) = 2W₁ + 3W₂ - 4W₃ @@ -102,7 +102,7 @@ end cache = TMLECache(dataset) # Test first ATE, only T₁ treatment varies Ψ = ATE( - target = :y, + outcome = :y, treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=1.)), confounders = [:W₁, :W₂] ) @@ -113,7 +113,7 @@ end ) tmle_result, cache = tmle!(cache, Ψ, η_spec, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₁) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # When Q is well specified but G is misspecified @@ -123,18 +123,18 @@ end ) tmle_result, cache = tmle!(cache, η_spec, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₁) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Test second ATE, two treatment varies Ψ = ATE( - target = :y, + outcome = :y, treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=0.)), confounders = [:W₁, :W₂] ) tmle_result, cache = tmle!(cache, Ψ, η_spec, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # When Q is well specified but G is misspecified @@ -144,14 +144,14 @@ end ) tmle_result, cache = tmle!(cache, η_spec, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end -@testset "Test Double Robustness ATE on continuous_target_categorical_treatment_pb" begin - dataset, Ψ₀ = continuous_target_categorical_treatment_pb(;n=10_000, control="TT", case="AA") +@testset "Test Double Robustness ATE on continuous_outcome_categorical_treatment_pb" begin + dataset, Ψ₀ = continuous_outcome_categorical_treatment_pb(;n=10_000, control="TT", case="AA") Ψ = ATE( - target = :y, + outcome = :y, treatment = (T=(case="AA", control="TT"),), confounders = [:W₁, :W₂, :W₃] ) @@ -163,7 +163,7 @@ end cache = TMLECache(dataset) tmle_result, cache = tmle!(cache, Ψ, η_spec, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @@ -177,14 +177,14 @@ end tmle_result, cache = tmle!(cache, η_spec, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache, target_name=:y) + test_fluct_decreases_risk(cache, outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end -@testset "Test Double Robustness ATE on binary_target_binary_treatment_pb" begin - dataset, Ψ₀ = binary_target_binary_treatment_pb(;n=10_000) +@testset "Test Double Robustness ATE on binary_outcome_binary_treatment_pb" begin + dataset, Ψ₀ = binary_outcome_binary_treatment_pb(;n=10_000) Ψ = ATE( - target = :y, + outcome = :y, treatment = (T=(case=true, control=false),), confounders = [:W] ) @@ -195,7 +195,7 @@ end ) tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @@ -208,15 +208,15 @@ end ) tmle_result, cache = tmle!(cache, η_spec, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) end -@testset "Test Double Robustness ATE on continuous_target_binary_treatment_pb" begin - dataset, Ψ₀ = continuous_target_binary_treatment_pb(n=10_000) +@testset "Test Double Robustness ATE on continuous_outcome_binary_treatment_pb" begin + dataset, Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) Ψ = ATE( - target = :y, + outcome = :y, treatment = (T=(case=true, control=false),), confounders = [:W₁, :W₂, :W₃] ) @@ -228,7 +228,7 @@ end tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @@ -241,7 +241,7 @@ end ) tmle_result, cache = tmle!(cache, η_spec, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end diff --git a/test/double_robustness_iate.jl b/test/double_robustness_iate.jl index 82a99af5..f89f2695 100644 --- a/test/double_robustness_iate.jl +++ b/test/double_robustness_iate.jl @@ -18,7 +18,7 @@ cont_interacter = InteractionTransformer(order=2) |> LinearRegressor cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1.) -function binary_target_binary_treatment_pb(;n=100) +function binary_outcome_binary_treatment_pb(;n=100) rng = StableRNG(123) μy_fn(W, T₁, T₂) = logistic.(2W[:, 1] .+ 1W[:, 2] .- 2W[:, 3] .- T₁ .+ T₂ .+ 2*T₁ .* T₂) # Sampling W: Bernoulli @@ -63,7 +63,7 @@ function binary_target_binary_treatment_pb(;n=100) end -function binary_target_categorical_treatment_pb(;n=100) +function binary_outcome_categorical_treatment_pb(;n=100) rng = StableRNG(123) function μy_fn(W, T, Hmach) Thot = MLJBase.transform(Hmach, T) @@ -121,7 +121,7 @@ function binary_target_categorical_treatment_pb(;n=100) end -function continuous_target_binary_treatment_pb(;n=100) +function continuous_outcome_binary_treatment_pb(;n=100) rng = StableRNG(123) μy_fn(W, T₁, T₂) = 2W[:, 1] .+ 1W[:, 2] .- 2W[:, 3] .- T₁ .+ T₂ .+ 2*T₁ .* T₂ # Sampling W: Bernoulli @@ -165,10 +165,10 @@ function continuous_target_binary_treatment_pb(;n=100) return (T₁=T₁, T₂=T₂, W₁=W[:, 1], W₂=W[:, 2], W₃=W[:, 3], y=y), IATE end -@testset "Test Double Robustness IATE on binary_target_binary_treatment_pb" begin - dataset, Ψ₀ = binary_target_binary_treatment_pb(n=10_000) +@testset "Test Double Robustness IATE on binary_outcome_binary_treatment_pb" begin + dataset, Ψ₀ = binary_outcome_binary_treatment_pb(n=10_000) Ψ = IATE( - target=:y, + outcome=:y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders = [:W₁, :W₂, :W₃] ) @@ -179,7 +179,7 @@ end ) tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @@ -193,17 +193,17 @@ end tmle_result, cache = tmle!(cache, η_spec, verbosity=0); test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-7) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial ≈ -0.0 atol=1e-1 end -@testset "Test Double Robustness IATE on continuous_target_binary_treatment_pb" begin - dataset, Ψ₀ = continuous_target_binary_treatment_pb(n=10_000) +@testset "Test Double Robustness IATE on continuous_outcome_binary_treatment_pb" begin + dataset, Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) Ψ = IATE( - target=:y, + outcome=:y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders = [:W₁, :W₂, :W₃] ) @@ -215,7 +215,7 @@ end tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @@ -229,15 +229,15 @@ end tmle_result, cache = tmle!(cache, η_spec, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end -@testset "Test Double Robustness IATE on binary_target_categorical_treatment_pb" begin - dataset, Ψ₀ = binary_target_categorical_treatment_pb(n=10_000) +@testset "Test Double Robustness IATE on binary_outcome_categorical_treatment_pb" begin + dataset, Ψ₀ = binary_outcome_categorical_treatment_pb(n=10_000) Ψ = IATE( - target=:y, + outcome=:y, treatment=(T₁=(case="CC", control="CG"), T₂=(case="AT", control="AA")), confounders = [:W₁, :W₂, :W₃] ) @@ -249,7 +249,7 @@ end tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial == 0 @@ -262,7 +262,7 @@ end tmle_result, cache = tmle!(cache, η_spec, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial ≈ -0.01 atol=1e-2 diff --git a/test/estimands.jl b/test/estimands.jl index 6cc1dd92..8b430928 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -15,10 +15,10 @@ using TMLE Ψ = CM( treatment=(T₁=1,), confounders=[:W₁, :W₂], - target =:y + outcome =:y ) @test TMLE.treatments(Ψ) == [:T₁] - @test TMLE.target(Ψ) == :y + @test TMLE.outcome(Ψ) == :y @test TMLE.treatment_and_confounders(Ψ) == [:W₁, :W₂, :T₁] @test TMLE.confounders_and_covariates(Ψ) == [:W₁, :W₂] @test TMLE.allcolumns(Ψ) == [:W₁, :W₂, :T₁, :y] @@ -26,11 +26,11 @@ using TMLE Ψ = CM( treatment=(T₁=1,), confounders=[:W₁, :W₂], - target =:y, + outcome =:y, covariates=[:C₁] ) @test TMLE.treatments(Ψ) == [:T₁] - @test TMLE.target(Ψ) == :y + @test TMLE.outcome(Ψ) == :y @test TMLE.treatment_and_confounders(Ψ) == [:W₁, :W₂, :T₁] @test TMLE.confounders_and_covariates(Ψ) == [:W₁, :W₂, :C₁] @test TMLE.allcolumns(Ψ) == [:W₁, :W₂, :C₁, :T₁, :y] @@ -39,44 +39,44 @@ end @testset "Test optimize_ordering" begin estimands = [ ATE( - target=:Y, + outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), confounders=[:W₁, :W₂] ), ATE( - target=:Y, + outcome=:Y, treatment=(T₁=(case=1, control=0),), confounders=[:W₁] ), ATE( - target=:Y₂, + outcome=:Y₂, treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), confounders=[:W₁, :W₂], ), CM( - target=:Y, + outcome=:Y, treatment=(T₁=0,), confounders=[:W₁], covariates=[:C₁] ), IATE( - target=:Y, + outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), confounders=[:W₁, :W₂] ), CM( - target=:Y, + outcome=:Y, treatment=(T₁=0,), confounders=[:W₁] ), ATE( - target=:Y₂, + outcome=:Y₂, treatment=(T₁=(case=1, control=0),), confounders=[:W₁], covariates=[:C₁] ), ATE( - target=:Y₂, + outcome=:Y₂, treatment=(T₁=(case=0, control=1), T₂=(case="AC", control="CC")), confounders=[:W₁, :W₂], covariates=[:C₂] diff --git a/test/fit_nuisance.jl b/test/fit_nuisance.jl index b52b8869..d039584e 100644 --- a/test/fit_nuisance.jl +++ b/test/fit_nuisance.jl @@ -35,7 +35,7 @@ end dataset = make_dataset(;n=50_000, rng=StableRNG(123)) cache = TMLECache(dataset) Ψ = ATE( - target=:Ybin, + outcome=:Ybin, treatment=(T₁=(case=true, control=false),), confounders = [:W₁, :W₂, :W₃] ) @@ -47,15 +47,15 @@ end TMLE.fit_nuisance!(cache, verbosity=0) # Check G fit Gcoefs = fitted_params(cache.η.G).coefs - @test Gcoefs[1][2] ≈ 0.298 atol=0.001 # target is 0.3 - @test Gcoefs[2][2] ≈ -1.441 atol=0.001 # target is -1.4 - @test Gcoefs[3][2] ≈ 1.953 atol=0.001 # target is 1.9 + @test Gcoefs[1][2] ≈ 0.298 atol=0.001 # outcome is 0.3 + @test Gcoefs[2][2] ≈ -1.441 atol=0.001 # outcome is -1.4 + @test Gcoefs[3][2] ≈ 1.953 atol=0.001 # outcome is 1.9 # Check Q fit Qcoefs = fitted_params(cache.η.Q).coefs - @test Qcoefs[1][2] ≈ 2.000 atol=0.001 # target is 2 - @test Qcoefs[2][2] ≈ -4.006 atol=0.001 # target is -4 - @test Qcoefs[3][2] ≈ -3.046 atol=0.001 # target is -3 - @test Qcoefs[4][2] ≈ -2.062 atol=0.001 # target is 2 on T₁=1 but T₁=0 is encoded + @test Qcoefs[1][2] ≈ 2.000 atol=0.001 # outcome is 2 + @test Qcoefs[2][2] ≈ -4.006 atol=0.001 # outcome is -4 + @test Qcoefs[3][2] ≈ -3.046 atol=0.001 # outcome is -3 + @test Qcoefs[4][2] ≈ -2.062 atol=0.001 # outcome is 2 on T₁=1 but T₁=0 is encoded @test fitted_params(cache.η.Q).intercept ≈ 2.053 atol=0.001 # Fluctuate the inital Q, the log_loss should decrease @@ -70,7 +70,7 @@ end dataset = make_dataset(;n=50_000, rng=StableRNG(123)) cache = TMLECache(dataset) Ψ = ATE( - target=:Ycont, + outcome=:Ycont, treatment=(T₁=(case=true, control=false),), confounders = [:W₁, :W₂, :W₃] ) @@ -82,15 +82,15 @@ end TMLE.fit_nuisance!(cache, verbosity=0) # Check G fit Gcoefs = fitted_params(cache.η.G).coefs - @test Gcoefs[1][2] ≈ 0.298 atol=0.001 # target is 0.3 - @test Gcoefs[2][2] ≈ -1.441 atol=0.001 # target is -1.4 - @test Gcoefs[3][2] ≈ 1.953 atol=0.001 # target is 1.9 + @test Gcoefs[1][2] ≈ 0.298 atol=0.001 # outcome is 0.3 + @test Gcoefs[2][2] ≈ -1.441 atol=0.001 # outcome is -1.4 + @test Gcoefs[3][2] ≈ 1.953 atol=0.001 # outcome is 1.9 # Check Q fit Qcoefs = fitted_params(cache.η.Q).coefs - @test Qcoefs[1][2] ≈ 2.001 atol=0.001 # target is 2 - @test Qcoefs[2][2] ≈ -4.002 atol=0.001 # target is -4 - @test Qcoefs[3][2] ≈ -2.999 atol=0.001 # target is -3 - @test Qcoefs[4][2] ≈ -1.988 atol=0.001 # target is 2 on T₁=1 but T₁=0 is encoded + @test Qcoefs[1][2] ≈ 2.001 atol=0.001 # outcome is 2 + @test Qcoefs[2][2] ≈ -4.002 atol=0.001 # outcome is -4 + @test Qcoefs[3][2] ≈ -2.999 atol=0.001 # outcome is -3 + @test Qcoefs[4][2] ≈ -1.988 atol=0.001 # outcome is 2 on T₁=1 but T₁=0 is encoded @test fitted_params(cache.η.Q).intercept ≈ 1.996 atol=0.001 # Fluctuate the inital Q, the RMSE should decrease diff --git a/test/helper_fns.jl b/test/helper_fns.jl index 6f024850..29ab0b15 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -4,24 +4,24 @@ using MLJBase using CategoricalArrays """ -The risk for continuous targets is the MSE +The risk for continuous outcomes is the MSE """ risk(ŷ, y) = rmse(TMLE.expected_value(ŷ), y) """ -The risk for binary targets is the LogLoss +The risk for binary outcomes is the LogLoss """ risk(ŷ, y::CategoricalArray) = mean(log_loss(ŷ, y)) """ - test_fluct_decreases_risk(cache; target_name::Symbol=nothing) + test_fluct_decreases_risk(cache; outcome_name::Symbol=nothing) The fluctuation is supposed to decrease the risk as its objective function is the risk itself. It seems that sometimes this is not entirely true in practice, so the test actually checks that it does not increase risk more than tol """ -function test_fluct_decreases_risk(cache; target_name::Symbol=nothing, atol=1e-6) - y = cache.data[:no_missing][target_name] +function test_fluct_decreases_risk(cache; outcome_name::Symbol=nothing, atol=1e-6) + y = cache.data[:no_missing][outcome_name] initial_risk = risk(cache.data[:Q₀], y) fluct_risk = risk(cache.data[:Qfluct], y) @test initial_risk >= fluct_risk || isapprox(initial_risk, fluct_risk, atol=atol) diff --git a/test/missing_management.jl b/test/missing_management.jl index 9f052634..4d491a76 100644 --- a/test/missing_management.jl +++ b/test/missing_management.jl @@ -43,13 +43,13 @@ end @testset "Test estimation with missing values and ordered factor treatment" begin dataset = dataset_with_missing_and_ordered_treatment(;n=1000) Ψ = ATE( - target=:y, + outcome=:y, confounders=[:W], treatment=(T=(case=1, control=0),)) η_spec = NuisanceSpec(LinearRegressor(), LogisticClassifier(lambda=0)) tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0) test_coverage(tmle_result, 1) - test_fluct_decreases_risk(cache; target_name=:y) + test_fluct_decreases_risk(cache; outcome_name=:y) end end diff --git a/test/non_regression_test.jl b/test/non_regression_test.jl index 3f3f6021..9e2dacb0 100644 --- a/test/non_regression_test.jl +++ b/test/non_regression_test.jl @@ -17,7 +17,7 @@ using MLJGLMInterface data[!, col] = float(data[!, col]) end Ψ = ATE( - target=:haz01, + outcome=:haz01, treatment = (parity01=(case=1, control=0),), confounders = confounders ) diff --git a/test/utils.jl b/test/utils.jl index 54f51f16..588ca54a 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -15,7 +15,7 @@ using MLJModels function test_params_match(estimands, expected_params) for (param, expected_param) in zip(estimands, expected_params) @test typeof(param) == typeof(expected_param) - @test param.target == expected_param.target + @test param.outcome == expected_param.outcome @test param.treatment == expected_param.treatment @test param.confounders == expected_param.confounders @test param.covariates == expected_param.covariates @@ -41,7 +41,7 @@ end @testset "Test indicator_fns" begin # Conditional Mean Ψ = CM( - target=:y, + outcome=:y, treatment=(T₁="A", T₂=1), confounders=[:W] ) @@ -49,7 +49,7 @@ end @test TMLE.indicator_fns(Ψ, TMLE.joint_name) == Dict{String, Float64}("A_&_1" => 1) # ATE Ψ = ATE( - target=:y, + outcome=:y, treatment=(T₁=(case="A", control="B"), T₂=(control=0, case=1)), confounders=[:W] ) @@ -59,7 +59,7 @@ end ) # 2-points IATE Ψ = IATE( - target=:y, + outcome=:y, treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0)), confounders=[:W] ) @@ -71,7 +71,7 @@ end ) # 3-points IATE Ψ = IATE( - target=:y, + outcome=:y, treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0), T₃=(control="D", case="C")), confounders=[:W] ) @@ -167,7 +167,7 @@ end fit!(Gmach, verbosity=0) Ψ = ATE( - target =:y, + outcome =:y, treatment=(t₁=(case="a", control="b"),), confounders = [:x1, :x2, :x3] ) @@ -193,7 +193,7 @@ end jointT) fit!(Gmach, verbosity=0) Ψ = IATE( - target =:y, + outcome =:y, treatment=(t₁=(case=1, control=0), t₂=(case=1, control=0)), confounders = [:x1, :x2, :x3] ) @@ -222,7 +222,7 @@ end jointT) fit!(Gmach, verbosity=0) Ψ = IATE( - target =:y, + outcome =:y, treatment=(t₁=(case="a", control="b"), t₂=(case=1, control=2), t₃=(case=true, control=false)), diff --git a/test/warm_restart.jl b/test/warm_restart.jl index 4436a5f9..89b833e0 100644 --- a/test/warm_restart.jl +++ b/test/warm_restart.jl @@ -30,7 +30,7 @@ function build_dataset(;n=100) # Treatment | Confounders T₁ = rand(rng, Uniform(), n) .< logistic.(0.5sin.(W₁) .- 1.5W₂) T₂ = rand(rng, Uniform(), n) .< logistic.(-3W₁ - 1.5W₂) - # target | Confounders, Covariates, Treatments + # outcome | Confounders, Covariates, Treatments y₁ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .+ T₁ + T₂.*W₂ .+ rand(rng, Normal(0, 0.1), n) y₂ = 1 .+ 10T₂ .+ rand(rng, Normal(0, 0.1), n) y₃ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .- 2T₂.*T₁.*W₂ .+ rand(rng, Normal(0, 0.1), n) @@ -52,7 +52,7 @@ table_types = (Tables.columntable, DataFrame) dataset = tt(build_dataset(;n=50_000)) # Define the estimand of interest Ψ = ATE( - target=:y₁, + outcome=:y₁, treatment=(T₁=(case=true, control=false),), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -76,13 +76,13 @@ table_types = (Tables.columntable, DataFrame) Ψ₀ = -1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Update the treatment specification # Nuisance estimands should not fitted again Ψ = ATE( - target=:y₁, + outcome=:y₁, treatment=(T₁=(case=false, control=true),), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -99,12 +99,12 @@ table_types = (Tables.columntable, DataFrame) Ψ₀ = 1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Remove the covariate variable, this will trigger the refit of Q Ψ = ATE( - target=:y₁, + outcome=:y₁, treatment=(T₁=(case=true, control=false),), confounders=[:W₁, :W₂], ) @@ -120,13 +120,13 @@ table_types = (Tables.columntable, DataFrame) Ψ₀ = -1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Change the treatment # This will trigger the refit of all η Ψ = ATE( - target=:y₁, + outcome=:y₁, treatment=(T₂=(case=true, control=false),), confounders=[:W₁, :W₂], ) @@ -142,7 +142,7 @@ table_types = (Tables.columntable, DataFrame) Ψ₀ = 0.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Remove a confounding variable @@ -150,7 +150,7 @@ table_types = (Tables.columntable, DataFrame) # Since we are not accounting for all confounders # we can't have ground truth coverage on this setting Ψ = ATE( - target=:y₁, + outcome=:y₁, treatment=(T₂=(case=true, control=false),), confounders=[:W₁], ) @@ -166,10 +166,10 @@ table_types = (Tables.columntable, DataFrame) @test tmle_result.estimand == Ψ test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - # Change the target + # Change the outcome # This will trigger the refit of Q only Ψ = ATE( - target=:y₂, + outcome=:y₂, treatment=(T₂=(case=true, control=false),), confounders=[:W₁], ) @@ -185,13 +185,13 @@ table_types = (Tables.columntable, DataFrame) Ψ₀ = 10 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₂) + test_fluct_decreases_risk(cache; outcome_name=:y₂) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Adding a covariate # This will trigger the refit of Q only Ψ = ATE( - target=:y₂, + outcome=:y₂, treatment=(T₂=(case=true, control=false),), confounders=[:W₁], covariates=[:C₁] @@ -199,7 +199,7 @@ table_types = (Tables.columntable, DataFrame) tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₂) + test_fluct_decreases_risk(cache; outcome_name=:y₂) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @@ -207,7 +207,7 @@ end dataset = tt(build_dataset(;n=50_000)) # Define the estimand of interest Ψ = ATE( - target=:y₁, + outcome=:y₁, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -221,12 +221,12 @@ end Ψ₀ = -0.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Let's switch case and control for T₂ Ψ = ATE( - target=:y₁, + outcome=:y₁, treatment=(T₁=(case=true, control=false), T₂=(case=false, control=true)), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -243,14 +243,14 @@ end Ψ₀ = -1.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @testset "Test Warm restart: CM, $tt" for tt in table_types dataset = tt(build_dataset(;n=50_000)) Ψ = CM( - target=:y₁, + outcome=:y₁, treatment=(T₁=true, T₂=true), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -263,12 +263,12 @@ end Ψ₀ = 3 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Let's switch case and control for T₂ Ψ = CM( - target=:y₁, + outcome=:y₁, treatment=(T₁=true, T₂=false), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -285,12 +285,12 @@ end Ψ₀ = 2.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - # Change the target + # Change the outcome Ψ = CM( - target=:y₂, + outcome=:y₂, treatment=(T₁=true, T₂=false), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -307,12 +307,12 @@ end Ψ₀ = 1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₂) + test_fluct_decreases_risk(cache; outcome_name=:y₂) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Change the treatment Ψ = CM( - target=:y₂, + outcome=:y₂, treatment=(T₂=true,), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -329,14 +329,14 @@ end Ψ₀ = 11 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₂) + test_fluct_decreases_risk(cache; outcome_name=:y₂) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @testset "Test Warm restart: pairwise IATE, $tt" for tt in table_types dataset = tt(build_dataset(;n=50_000)) Ψ = IATE( - target=:y₃, + outcome=:y₃, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -349,12 +349,12 @@ end Ψ₀ = -1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₃) + test_fluct_decreases_risk(cache; outcome_name=:y₃) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Remove covariate from fit Ψ = IATE( - target=:y₃, + outcome=:y₃, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders=[:W₁, :W₂], ) @@ -373,7 +373,7 @@ end # Changing the treatments values Ψ = IATE( - target=:y₃, + outcome=:y₃, treatment=(T₁=(case=false, control=true), T₂=(case=true, control=false)), confounders=[:W₁, :W₂], ) @@ -389,7 +389,7 @@ end Ψ₀ = - Ψ₀ @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₃) + test_fluct_decreases_risk(cache; outcome_name=:y₃) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @@ -398,7 +398,7 @@ end dataset = build_dataset(;n=50_000) # Define the estimand of interest Ψ = ATE( - target=:y₁, + outcome=:y₁, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -412,11 +412,11 @@ end Ψ₀ = -0.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) Ψnew = ATE( - target=:y₁, + outcome=:y₁, treatment=(T₁=(case=false, control=true), T₂=(case=false, control=true)), confounders=[:W₁, :W₂], covariates=[:C₁] @@ -438,7 +438,7 @@ end Ψ₀ = 0.5 @test tmle_result.estimand == Ψnew test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; target_name=:y₁) + test_fluct_decreases_risk(cache; outcome_name=:y₁) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end From 20355f469ef758a4439a97dde8726df2a11db686 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 2 Aug 2023 16:31:31 +0100 Subject: [PATCH 004/151] add scm --- src/TMLE.jl | 4 +- src/scm.jl | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++++ test/scm.jl | 53 ++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 1 deletion(-) create mode 100644 src/scm.jl create mode 100644 test/scm.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index 29a66ddc..5f28788c 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -24,6 +24,7 @@ import AbstractDifferentiation as AD # EXPORTS # ############################################################################# +export StructuralCausalModel, SCM, StructuralEquation, SE, StaticConfoundedModel export NuisanceSpec, TMLECache, update!, CM, ATE, IATE export tmle, tmle! export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint @@ -34,6 +35,7 @@ export estimands_from_yaml, estimands_to_yaml, optimize_ordering, optimize_order # INCLUDES # ############################################################################# +include("scm.jl") include("treatment_transformer.jl") include("estimands.jl") include("utils.jl") @@ -125,6 +127,6 @@ function run_precompile_workload() end -run_precompile_workload() +# run_precompile_workload() end diff --git a/src/scm.jl b/src/scm.jl new file mode 100644 index 00000000..cce6a18f --- /dev/null +++ b/src/scm.jl @@ -0,0 +1,104 @@ +##################################################################### +### Structural Equation ### +##################################################################### + +# const AVAILABLE_SPECIFICATIONS = Set([:density, :mean]) + +SelfReferringEquationError(outcome) = + ArgumentError(string("Variable ", outcome, " appears on both sides of the equation.")) + +# UnavailableSpecificationError(key) = +# ArgumentError(string("Unavailable specification: ", key, ". Available specifications are: ", join(AVAILABLE_SPECIFICATIONS, ", "), ".")) + +# SpecificationShouldBeModelError(key) = +# ArgumentError(string("Specification ", key, " should be a MLJBase.Model.")) + +struct StructuralEquation + outcome::Symbol + parents::Vector{Symbol} + model::Model + function StructuralEquation(outcome, parents, model) + outcome ∉ parents || throw(SelfReferringEquationError(outcome)) + # for (specname, model) ∈ zip(keys(specifications), specifications) + # specname ∈ AVAILABLE_SPECIFICATIONS || throw(UnavailableSpecificationError(specname)) + # model isa Model || throw(SpecificationShouldBeModelError(specname)) + # end + if model isa Nothing + return new(outcome, parents) + else + return new(outcome, parents, model) + end + end +end + +StructuralEquation(outcome, parents; model=nothing) = StructuralEquation(outcome, parents, model) + +const SE = StructuralEquation + +assign_model!(eq::SE, model::Nothing) = nothing +assign_model!(eq::SE, model::Model) = eq.model = model + +outcome(se::SE) = se.outcome +parents(se::SE) = se.parents + +##################################################################### +### Structural Causal Model ### +##################################################################### + +AlreadyAssignedError(key) = ArgumentError(string("Variable ", key, " is already assigned in the SCM.")) + +struct StructuralCausalModel + equations::Dict{Symbol, StructuralEquation} +end + +const SCM = StructuralCausalModel + +StructuralCausalModel(equations::Vararg{SE}) = + StructuralCausalModel(Dict(outcome(eq) => eq for eq in equations)) + +function Base.push!(scm::SCM, eq::SE) + key = outcome(eq) + scm[key] = eq +end + +function Base.setindex!(scm::SCM, eq::SE, key::Symbol) + if haskey(scm.equations, key) + throw(AlreadyAssignedError(key)) + end + scm.equations[key] = eq +end + +Base.getindex(scm::SCM, key::Symbol) = scm.equations[key] + +function Base.getproperty(scm::StructuralCausalModel, key::Symbol) + hasfield(StructuralCausalModel, key) && return getfield(scm, key) + return scm.equations[key] +end + +parents(scm::StructuralCausalModel, key::Symbol) = parents(getproperty(scm, key)) + +##################################################################### +### StaticConfoundedModel ### +##################################################################### + +vcat_covariates(covariates::Nothing, treatment, confounders) = vcat(treatment, confounders) +vcat_covariates(covariates, treatment, confounders) = vcat(covariates, treatment, confounders) + +function StaticConfoundedModel( + outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}}; + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_spec = LinearRegressor(), + treatment_spec = LinearBinaryClassifier() + ) + Yeq = StructuralEquation( + outcome, + vcat_covariates(covariates, treatment, confounders), + outcome_spec + ) + Teq = StructuralEquation( + treatment, + vcat(confounders), + treatment_spec + ) + return StructuralCausalModel(Yeq, Teq) +end \ No newline at end of file diff --git a/test/scm.jl b/test/scm.jl new file mode 100644 index 00000000..1f3014fc --- /dev/null +++ b/test/scm.jl @@ -0,0 +1,53 @@ +module TestSCM + +using Test +using TMLE +using MLJGLMInterface + +@testset "Test SE" begin + # Incorrect Structural Equations + @test_throws TMLE.SelfReferringEquationError(:Y) SE(:Y, [:Y]) + # Correct Structural Equations + @test SE(:Y, [:T], LinearBinaryClassifier()) isa SE +end + +@testset "Test SCM" begin + scm = SCM() + # setindex! + Yeq = SE(:Y, [:T, :W, :C]) + scm[:Y] = Yeq + # push! + Teq = SE(:T, [:W]) + push!(scm, Teq) + # getindex + @test scm[:Y] === Yeq + # getproperty + @test scm.T === Teq + # Already defined equation + @test_throws TMLE.AlreadyAssignedError(:Y) push!(scm, SE(:Y, [:T])) +end + +@testset "Test StaticConfoundedModel" begin + # With covariates + scm = StaticConfoundedModel(:Y, :T, [:W₁, :W₂], covariates=[:C₁]) + @test TMLE.parents(scm, :Y) == [:C₁, :T, :W₁, :W₂] + @test TMLE.parents(scm, :T) == [:W₁, :W₂] + @test scm.Y.model isa LinearRegressor + @test scm.T.model isa LinearBinaryClassifier + @test scm.T.model.fit_intercept === true + + # Without covariates + scm = StaticConfoundedModel(:Y, :T, :W₁, outcome_spec=LinearBinaryClassifier(), treatment_spec=LinearBinaryClassifier(fit_intercept=false)) + @test TMLE.parents(scm, :Y) == [:T, :W₁] + @test TMLE.parents(scm, :T) == [:W₁] + @test scm.Y.model isa LinearBinaryClassifier + @test scm.T.model.fit_intercept === false + + # With 1 covariate + scm = StaticConfoundedModel(:Y, :T, :W₁, covariates=:C₁) + @test TMLE.parents(scm, :Y) == [:C₁, :T, :W₁] + @test TMLE.parents(scm, :T) == [:W₁] +end + +end +true \ No newline at end of file From ef4485193532412af36f928618ca878b0dcefbc4 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 3 Aug 2023 15:40:23 +0100 Subject: [PATCH 005/151] add show methods to scm and se --- src/scm.jl | 24 ++++++++++++++++++++++++ test/scm.jl | 10 +++++++++- 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/src/scm.jl b/src/scm.jl index cce6a18f..6bffeabe 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -35,6 +35,14 @@ StructuralEquation(outcome, parents; model=nothing) = StructuralEquation(outcome const SE = StructuralEquation +function string(eq::SE; subscript="") + eq_string = string(eq.outcome, " = f", subscript, "(", join(eq.parents, ", "), ")") + model_string = isdefined(eq, :model) ? string(", ", typeof(eq.model)) : "" + return string(eq_string, model_string) +end + +Base.show(io::IO, eq::SE) = println(io, string(eq)) + assign_model!(eq::SE, model::Nothing) = nothing assign_model!(eq::SE, model::Model) = eq.model = model @@ -56,6 +64,22 @@ const SCM = StructuralCausalModel StructuralCausalModel(equations::Vararg{SE}) = StructuralCausalModel(Dict(outcome(eq) => eq for eq in equations)) +function string(scm::SCM) + scm_string = """ + Structural Causal Model: + ----------------------- + """ + for (index, (_, eq)) in enumerate(scm.equations) + digits = split(string(index), "") + subscript = join(Meta.parse("'\\U0208$digit'") for digit in digits) + scm_string = string(scm_string, string(eq;subscript=subscript), "\n") + end + return scm_string +end + +Base.show(io::IO, scm::SCM) = println(io, string(scm)) + + function Base.push!(scm::SCM, eq::SE) key = outcome(eq) scm[key] = eq diff --git a/test/scm.jl b/test/scm.jl index 1f3014fc..c72bfe86 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -8,7 +8,13 @@ using MLJGLMInterface # Incorrect Structural Equations @test_throws TMLE.SelfReferringEquationError(:Y) SE(:Y, [:Y]) # Correct Structural Equations - @test SE(:Y, [:T], LinearBinaryClassifier()) isa SE + eq = SE(:Y, [:T, :W], LinearBinaryClassifier()) + @test eq isa SE + @test string(eq) == "Y = f(T, W), LinearBinaryClassifier" + + eq = SE(:Y, [:T, :W]) + @test !isdefined(eq, :model) + @test string(eq) == "Y = f(T, W)" end @testset "Test SCM" begin @@ -25,6 +31,8 @@ end @test scm.T === Teq # Already defined equation @test_throws TMLE.AlreadyAssignedError(:Y) push!(scm, SE(:Y, [:T])) + # string representation + @test string(scm) == "Structural Causal Model:\n-----------------------\nT = f₁(W)\nY = f₂(T, W, C)\n" end @testset "Test StaticConfoundedModel" begin From 96586f08763415dee5820d27059531740e5ab6a9 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 3 Aug 2023 17:12:33 +0100 Subject: [PATCH 006/151] update equation struct to be fitted later --- src/scm.jl | 31 +++++++++++++------------------ test/scm.jl | 7 +++---- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/scm.jl b/src/scm.jl index 6bffeabe..cf3b5afd 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -13,21 +13,14 @@ SelfReferringEquationError(outcome) = # SpecificationShouldBeModelError(key) = # ArgumentError(string("Specification ", key, " should be a MLJBase.Model.")) -struct StructuralEquation +mutable struct StructuralEquation outcome::Symbol parents::Vector{Symbol} - model::Model + model::Union{Model, Nothing} + mach::Union{Machine, Nothing} function StructuralEquation(outcome, parents, model) outcome ∉ parents || throw(SelfReferringEquationError(outcome)) - # for (specname, model) ∈ zip(keys(specifications), specifications) - # specname ∈ AVAILABLE_SPECIFICATIONS || throw(UnavailableSpecificationError(specname)) - # model isa Model || throw(SpecificationShouldBeModelError(specname)) - # end - if model isa Nothing - return new(outcome, parents) - else - return new(outcome, parents, model) - end + return new(outcome, parents, model, nothing) end end @@ -35,13 +28,15 @@ StructuralEquation(outcome, parents; model=nothing) = StructuralEquation(outcome const SE = StructuralEquation -function string(eq::SE; subscript="") +function string_repr(eq::SE; subscript="") eq_string = string(eq.outcome, " = f", subscript, "(", join(eq.parents, ", "), ")") - model_string = isdefined(eq, :model) ? string(", ", typeof(eq.model)) : "" - return string(eq_string, model_string) + if eq.model !== nothing + eq_string = string(eq_string, ", ", typeof(eq.model), ", fitted=", eq.mach !== nothing) + end + return eq_string end -Base.show(io::IO, eq::SE) = println(io, string(eq)) +Base.show(io::IO, eq::SE) = println(io, string_repr(eq)) assign_model!(eq::SE, model::Nothing) = nothing assign_model!(eq::SE, model::Model) = eq.model = model @@ -64,7 +59,7 @@ const SCM = StructuralCausalModel StructuralCausalModel(equations::Vararg{SE}) = StructuralCausalModel(Dict(outcome(eq) => eq for eq in equations)) -function string(scm::SCM) +function string_repr(scm::SCM) scm_string = """ Structural Causal Model: ----------------------- @@ -72,12 +67,12 @@ function string(scm::SCM) for (index, (_, eq)) in enumerate(scm.equations) digits = split(string(index), "") subscript = join(Meta.parse("'\\U0208$digit'") for digit in digits) - scm_string = string(scm_string, string(eq;subscript=subscript), "\n") + scm_string = string(scm_string, string_repr(eq;subscript=subscript), "\n") end return scm_string end -Base.show(io::IO, scm::SCM) = println(io, string(scm)) +Base.show(io::IO, scm::SCM) = println(io, string_repr(scm)) function Base.push!(scm::SCM, eq::SE) diff --git a/test/scm.jl b/test/scm.jl index c72bfe86..7f38cb54 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -10,11 +10,10 @@ using MLJGLMInterface # Correct Structural Equations eq = SE(:Y, [:T, :W], LinearBinaryClassifier()) @test eq isa SE - @test string(eq) == "Y = f(T, W), LinearBinaryClassifier" + @test TMLE.string_repr(eq) == "Y = f(T, W), LinearBinaryClassifier, fitted=false" eq = SE(:Y, [:T, :W]) - @test !isdefined(eq, :model) - @test string(eq) == "Y = f(T, W)" + @test TMLE.string_repr(eq) == "Y = f(T, W)" end @testset "Test SCM" begin @@ -32,7 +31,7 @@ end # Already defined equation @test_throws TMLE.AlreadyAssignedError(:Y) push!(scm, SE(:Y, [:T])) # string representation - @test string(scm) == "Structural Causal Model:\n-----------------------\nT = f₁(W)\nY = f₂(T, W, C)\n" + @test TMLE.string_repr(scm) == "Structural Causal Model:\n-----------------------\nT = f₁(W)\nY = f₂(T, W, C)\n" end @testset "Test StaticConfoundedModel" begin From 37d3ec6af56ef3ee8da13cb49fd5b0561b844f27 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 3 Aug 2023 20:29:54 +0100 Subject: [PATCH 007/151] add se,scm fit! --- src/TMLE.jl | 7 +++++-- src/scm.jl | 28 ++++++++++++++++++++++++++- test/scm.jl | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 86 insertions(+), 3 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index 5f28788c..8878bb0c 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -24,12 +24,15 @@ import AbstractDifferentiation as AD # EXPORTS # ############################################################################# -export StructuralCausalModel, SCM, StructuralEquation, SE, StaticConfoundedModel -export NuisanceSpec, TMLECache, update!, CM, ATE, IATE +export SE, StructuralEquation, outcome, treatments +export StructuralCausalModel, SCM, StaticConfoundedModel, equations +export fit! +export NuisanceSpec, TMLECache, update!, CM, ATE, IATE, isidentified export tmle, tmle! export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export estimands_from_yaml, estimands_to_yaml, optimize_ordering, optimize_ordering! +export TreatmentTransformer # ############################################################################# # INCLUDES diff --git a/src/scm.jl b/src/scm.jl index cf3b5afd..e10c426e 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -28,10 +28,27 @@ StructuralEquation(outcome, parents; model=nothing) = StructuralEquation(outcome const SE = StructuralEquation +function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) + if eq.mach === nothing || force === true + verbosity >= 0 && @info(string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".")) + X = selectcols(dataset, parents(eq)) + y = Tables.getcolumn(dataset, outcome(eq)) + mach = machine(eq.model, X, y, cache=cache) + MLJBase.fit!(mach, verbosity=verbosity-1) + eq.mach = mach + else + verbosity >= 0 && @info(string( + "Structural Equation corresponding to variable ", + outcome(eq), + " already fitted, skipping. Set `force=true` to force refit." + )) + end +end + function string_repr(eq::SE; subscript="") eq_string = string(eq.outcome, " = f", subscript, "(", join(eq.parents, ", "), ")") if eq.model !== nothing - eq_string = string(eq_string, ", ", typeof(eq.model), ", fitted=", eq.mach !== nothing) + eq_string = string(eq_string, ", ", nameof(typeof(eq.model)), ", fitted=", eq.mach !== nothing) end return eq_string end @@ -59,6 +76,9 @@ const SCM = StructuralCausalModel StructuralCausalModel(equations::Vararg{SE}) = StructuralCausalModel(Dict(outcome(eq) => eq for eq in equations)) + +equations(scm::SCM) = scm.equations + function string_repr(scm::SCM) scm_string = """ Structural Causal Model: @@ -96,6 +116,12 @@ end parents(scm::StructuralCausalModel, key::Symbol) = parents(getproperty(scm, key)) +function MLJBase.fit!(scm::SCM, dataset; verbosity=1, cache=true, force=false) + for (_, eq) in equations(scm) + fit!(eq, dataset; verbosity=verbosity, cache=cache, force=force) + end +end + ##################################################################### ### StaticConfoundedModel ### ##################################################################### diff --git a/test/scm.jl b/test/scm.jl index 7f38cb54..af750776 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -3,6 +3,10 @@ module TestSCM using Test using TMLE using MLJGLMInterface +using CategoricalArrays +using Random +using StableRNGs +using MLJBase @testset "Test SE" begin # Incorrect Structural Equations @@ -56,5 +60,55 @@ end @test TMLE.parents(scm, :T) == [:W₁] end +@testset "Test fit!" begin + rng = StableRNG(123) + n = 100 + dataset = ( + Ycat = categorical(rand(rng, [0, 1], n)), + Ycont = rand(rng, n), + T₁ = categorical(rand(rng, [0, 1], n)), + T₂ = categorical(rand(rng, [0, 1], n)), + W₁₁ = rand(rng, n), + W₁₂ = rand(rng, n), + W₂₁ = rand(rng, n), + W₂₂ = rand(rng, n), + C = rand(rng, n) + ) + + scm = SCM( + SE(:Ycat, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], TreatmentTransformer() |> LinearBinaryClassifier()), + SE(:Ycont, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], TreatmentTransformer() |> LinearRegressor()), + SE(:T₁, [:W₁₁, :W₁₂], LinearBinaryClassifier()), + SE(:T₂, [:W₂₁, :W₂₂], LinearBinaryClassifier()), + ) + + # Fits all equations in SCM + fit_log_sequence = ( + (:info, "Fitting Structural Equation corresponding to variable Ycont."), + (:info, "Fitting Structural Equation corresponding to variable Ycat."), + (:info, "Fitting Structural Equation corresponding to variable T₁."), + (:info, "Fitting Structural Equation corresponding to variable T₂."), + ) + @test_logs fit_log_sequence... fit!(scm, dataset, verbosity = 1) + for (key, eq) in equations(scm) + @test eq.mach isa Machine + @test isdefined(eq.mach, :data) + end + # Refit will not do anything + nofit_log_sequence = ( + (:info, "Structural Equation corresponding to variable Ycont already fitted, skipping. Set `force=true` to force refit."), + (:info, "Structural Equation corresponding to variable Ycat already fitted, skipping. Set `force=true` to force refit."), + (:info, "Structural Equation corresponding to variable T₁ already fitted, skipping. Set `force=true` to force refit."), + (:info, "Structural Equation corresponding to variable T₂ already fitted, skipping. Set `force=true` to force refit."), + ) + @test_logs nofit_log_sequence... fit!(scm, dataset, verbosity = 1) + # Force refit and set cache to false + @test_logs fit_log_sequence... fit!(scm, dataset, verbosity = 1, force=true, cache=false) + for (key, eq) in equations(scm) + @test eq.mach isa Machine + @test !isdefined(eq.mach, :data) + end +end + end true \ No newline at end of file From 35fe73df74879e7696c0506b0a86e1fbd37a889c Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 3 Aug 2023 20:54:06 +0100 Subject: [PATCH 008/151] add fit estimands and reset scm --- src/TMLE.jl | 2 +- src/estimands.jl | 73 +++++++++++++++++---- src/scm.jl | 19 +++--- test/estimands.jl | 163 ++++++++++++++++++++++++++++++++++++++-------- test/scm.jl | 6 ++ 5 files changed, 214 insertions(+), 49 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index 8878bb0c..bb115e95 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -26,7 +26,7 @@ import AbstractDifferentiation as AD export SE, StructuralEquation, outcome, treatments export StructuralCausalModel, SCM, StaticConfoundedModel, equations -export fit! +export fit!, reset! export NuisanceSpec, TMLECache, update!, CM, ATE, IATE, isidentified export tmle, tmle! export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint diff --git a/src/estimands.jl b/src/estimands.jl index 54978aaa..5c431010 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -13,13 +13,19 @@ const causal_graph = """ """ ##################################################################### -### Abstract Estimand ### +### Abstract Estimand ### ##################################################################### """ A Estimand is a functional on distribution space Ψ: ℳ → ℜ. """ abstract type Estimand end +function MLJBase.fit!(Ψ::Estimand, dataset; verbosity=1, cache=true, force=false) + for eq in equations_to_fit(Ψ) + fit!(eq, dataset, verbosity=verbosity, cache=cache, force=force) + end +end + ##################################################################### ### Conditional Mean ### ##################################################################### @@ -58,10 +64,9 @@ CM₂ = CM( ``` """ @option struct CM <: Estimand + scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple - confounders::Vector{Symbol} - covariates::Vector{Symbol} = Symbol[] end ##################################################################### @@ -102,13 +107,11 @@ ATE₂ = ATE( ``` """ @option struct ATE <: Estimand + scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple - confounders::Vector{Symbol} - covariates::Vector{Symbol} = Symbol[] end - ##################################################################### ### Interaction Average Treatment Effect ### ##################################################################### @@ -140,10 +143,9 @@ IATE₁ = IATE( ``` """ @option struct IATE <: Estimand + scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple - confounders::Vector{Symbol} - covariates::Vector{Symbol} = Symbol[] end @@ -204,15 +206,62 @@ NuisanceSpec(Q, G; H=TreatmentTransformer(), F=F_model(target_scitype(Q)), cache ### Methods ### ##################################################################### +CMCompositeEstimand = Union{CM, ATE, IATE} + +Base.show(io::IO, Ψ::T) where T <: CMCompositeEstimand = println(io, T) + +equations_to_fit(Ψ::CMCompositeEstimand) = (Ψ.scm[outcome(Ψ)], (Ψ.scm[t] for t in treatments(Ψ))...) + +variable_not_indataset(role, variable) = string(role, " variable: ", variable, " is not in the dataset.") + +function isidentified(Ψ::CMCompositeEstimand, dataset) + reasons = String[] + sch = Tables.schema(dataset) + # Check outcome variable + outcome(Ψ) ∈ sch.names || push!( + reasons, + variable_not_indataset("Outcome", outcome(Ψ)) + ) + + # Check Treatment variable + for treatment in treatments(Ψ) + treatment ∈ sch.names || push!( + reasons, + variable_not_indataset("Treatment", treatment) + ) + end + + # Check adjustment set + for confounder in Set(vcat(confounders(Ψ)...)) + confounder ∈ sch.names || push!( + reasons, + variable_not_indataset("Confounding", confounder) + ) + end + + return length(reasons) == 0, reasons +end + +outcome(Ψ::CMCompositeEstimand) = Ψ.outcome + selectcols(data, cols) = data |> TableOperations.select(cols...) |> Tables.columntable -confounders(Ψ::Estimand) = Ψ.confounders +function confounders(Ψ::CMCompositeEstimand) + confounders_ = [] + treatments_ = Tuple(treatments(Ψ)) + for treatment in treatments_ + push!( + confounders_, + intersect(parents(Ψ.scm, outcome(Ψ)), parents(Ψ.scm, treatment)) + ) + end + return NamedTuple{treatments_}(confounders_) +end + confounders(dataset, Ψ) = selectcols(dataset, confounders(Ψ)) -covariates(Ψ::Estimand) = Ψ.covariates -covariates(dataset, Ψ) = selectcols(dataset, covariates(Ψ)) -treatments(Ψ::Estimand) = collect(keys(Ψ.treatment)) +treatments(Ψ::CMCompositeEstimand) = collect(keys(Ψ.treatment)) treatments(dataset, Ψ) = selectcols(dataset, treatments(Ψ)) outcome(Ψ::Estimand) = Ψ.outcome diff --git a/src/scm.jl b/src/scm.jl index e10c426e..b244ddf0 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -2,17 +2,9 @@ ### Structural Equation ### ##################################################################### -# const AVAILABLE_SPECIFICATIONS = Set([:density, :mean]) - SelfReferringEquationError(outcome) = ArgumentError(string("Variable ", outcome, " appears on both sides of the equation.")) -# UnavailableSpecificationError(key) = -# ArgumentError(string("Unavailable specification: ", key, ". Available specifications are: ", join(AVAILABLE_SPECIFICATIONS, ", "), ".")) - -# SpecificationShouldBeModelError(key) = -# ArgumentError(string("Specification ", key, " should be a MLJBase.Model.")) - mutable struct StructuralEquation outcome::Symbol parents::Vector{Symbol} @@ -30,7 +22,7 @@ const SE = StructuralEquation function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) if eq.mach === nothing || force === true - verbosity >= 0 && @info(string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".")) + verbosity >= 1 && @info(string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".")) X = selectcols(dataset, parents(eq)) y = Tables.getcolumn(dataset, outcome(eq)) mach = machine(eq.model, X, y, cache=cache) @@ -45,6 +37,8 @@ function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) end end +reset!(eq::SE) = eq.mach = nothing + function string_repr(eq::SE; subscript="") eq_string = string(eq.outcome, " = f", subscript, "(", join(eq.parents, ", "), ")") if eq.model !== nothing @@ -117,11 +111,16 @@ end parents(scm::StructuralCausalModel, key::Symbol) = parents(getproperty(scm, key)) function MLJBase.fit!(scm::SCM, dataset; verbosity=1, cache=true, force=false) - for (_, eq) in equations(scm) + for eq in values(equations(scm)) fit!(eq, dataset; verbosity=verbosity, cache=cache, force=force) end end +function reset!(scm::SCM) + for eq in values(equations(scm)) + reset!(eq) + end +end ##################################################################### ### StaticConfoundedModel ### ##################################################################### diff --git a/test/estimands.jl b/test/estimands.jl index 8b430928..478fc377 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -2,38 +2,149 @@ module TestEstimands using Test using TMLE +using CategoricalArrays +using Random +using StableRNGs +using MLJBase +using MLJGLMInterface -@testset "Test variables accessors" begin - n = 10 +@testset "Test methods" begin + n = 5 dataset = ( - W₁=rand(n), - W₂=rand(n), - C₁=rand(n), - T₁=rand([1, 0], n), - y=rand(n), - ) + Y = rand(n), + T = categorical([0, 1, 1, 0, 0]), + W₁ = rand(n), + W₂ = rand(n) + ) + # CM + scm = StaticConfoundedModel( + :Y, :T, [:W₁, :W₂], + covariates=:C₁ + ) Ψ = CM( - treatment=(T₁=1,), - confounders=[:W₁, :W₂], - outcome =:y - ) - @test TMLE.treatments(Ψ) == [:T₁] - @test TMLE.outcome(Ψ) == :y - @test TMLE.treatment_and_confounders(Ψ) == [:W₁, :W₂, :T₁] - @test TMLE.confounders_and_covariates(Ψ) == [:W₁, :W₂] - @test TMLE.allcolumns(Ψ) == [:W₁, :W₂, :T₁, :y] + scm=scm, + treatment=(T=1,), + outcome =:Y + ) + + @test TMLE.treatments(Ψ) == [:T] + @test TMLE.outcome(Ψ) == :Y + @test TMLE.confounders(Ψ) == (T=[:W₁, :W₂],) + identified, reasons = isidentified(Ψ, dataset) + @test identified === true + @test reasons == [] + ηs = TMLE.equations_to_fit(Ψ) + @test ηs === (scm.Y, scm.T) + # ATE + scm = SCM( + SE(:Z, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁]), + SE(:T₁, [:W₁₁, :W₁₂]), + SE(:T₂, [:W₂₁]), + ) + Ψ = ATE( + scm=scm, + treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), + outcome =:Z, + ) + @test TMLE.treatments(Ψ) == [:T₁, :T₂] + @test TMLE.outcome(Ψ) == :Z + @test TMLE.confounders(Ψ) == (T₁=[:W₁₁, :W₁₂], T₂=[:W₂₁]) + identified, reasons = isidentified(Ψ, dataset) + @test identified === false + expected_reasons = [ + "Outcome variable: Z is not in the dataset.", + "Treatment variable: T₁ is not in the dataset.", + "Treatment variable: T₂ is not in the dataset.", + "Confounding variable: W₁₂ is not in the dataset.", + "Confounding variable: W₂₁ is not in the dataset.", + "Confounding variable: W₁₁ is not in the dataset." + ] + @test reasons == expected_reasons + ηs = TMLE.equations_to_fit(Ψ) + @test ηs === (scm.Z, scm.T₁, scm.T₂) + # IATE + Ψ = IATE( + scm=scm, + treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), + outcome =:Z, + ) + @test TMLE.treatments(Ψ) == [:T₁, :T₂] + @test TMLE.outcome(Ψ) == :Z + @test TMLE.confounders(Ψ) == (T₁=[:W₁₁, :W₁₂], T₂=[:W₂₁]) + identified, reasons = isidentified(Ψ, dataset) + @test identified === false + @test reasons == expected_reasons + ηs = TMLE.equations_to_fit(Ψ) + @test ηs === (scm.Z, scm.T₁, scm.T₂) +end +@testset "Test fit!" begin + rng = StableRNG(123) + n = 100 + dataset = ( + Ycat = categorical(rand(rng, [0, 1], n)), + Ycont = rand(rng, n), + T₁ = categorical(rand(rng, [0, 1], n)), + T₂ = categorical(rand(rng, [0, 1], n)), + W₁₁ = rand(rng, n), + W₁₂ = rand(rng, n), + W₂₁ = rand(rng, n), + W₂₂ = rand(rng, n), + C = rand(rng, n) + ) + + scm = SCM( + SE(:Ycat, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], TreatmentTransformer() |> LinearBinaryClassifier()), + SE(:Ycont, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], TreatmentTransformer() |> LinearRegressor()), + SE(:T₁, [:W₁₁, :W₁₂], LinearBinaryClassifier()), + SE(:T₂, [:W₂₁, :W₂₂], LinearBinaryClassifier()), + ) + + # Fits CM required equations Ψ = CM( + scm=scm, treatment=(T₁=1,), - confounders=[:W₁, :W₂], - outcome =:y, - covariates=[:C₁] - ) - @test TMLE.treatments(Ψ) == [:T₁] - @test TMLE.outcome(Ψ) == :y - @test TMLE.treatment_and_confounders(Ψ) == [:W₁, :W₂, :T₁] - @test TMLE.confounders_and_covariates(Ψ) == [:W₁, :W₂, :C₁] - @test TMLE.allcolumns(Ψ) == [:W₁, :W₂, :C₁, :T₁, :y] + outcome=:Ycat + ) + fit!(Ψ, dataset, verbosity=0) + @test scm.Ycat.mach isa Machine + @test scm.T₁.mach isa Machine + @test scm.Ycont.mach isa Nothing + @test scm.T₂.mach isa Nothing + + # Fits ATE required equations that haven't been fitted yet. + Ψ = ATE( + scm=scm, + treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), + outcome =:Ycat, + ) + log_sequence = ( + (:info, "Structural Equation corresponding to variable Ycat already fitted, skipping. Set `force=true` to force refit."), + (:info, "Structural Equation corresponding to variable T₁ already fitted, skipping. Set `force=true` to force refit."), + (:info, "Fitting Structural Equation corresponding to variable T₂."), + ) + @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) + @test scm.Ycat.mach isa Machine + @test scm.T₁.mach isa Machine + @test scm.T₂.mach isa Machine + @test scm.Ycont.mach isa Nothing + + # Fits IATE required equations that haven't been fitted yet. + Ψ = IATE( + scm=scm, + treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), + outcome =:Ycont, + ) + log_sequence = ( + (:info, "Fitting Structural Equation corresponding to variable Ycont."), + (:info, "Structural Equation corresponding to variable T₁ already fitted, skipping. Set `force=true` to force refit."), + (:info, "Structural Equation corresponding to variable T₂ already fitted, skipping. Set `force=true` to force refit."), + ) + @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) + @test scm.Ycat.mach isa Machine + @test scm.T₁.mach isa Machine + @test scm.T₂.mach isa Machine + @test scm.Ycont.mach isa Machine end @testset "Test optimize_ordering" begin diff --git a/test/scm.jl b/test/scm.jl index af750776..fac1a4e5 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -108,6 +108,12 @@ end @test eq.mach isa Machine @test !isdefined(eq.mach, :data) end + + # Reset scm + reset!(scm) + for (key, eq) in equations(scm) + @test eq.mach isa Nothing + end end end From 79d365722d528978c81a6c376ad716dfba06cf84 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 10:47:04 +0100 Subject: [PATCH 009/151] add another SCM constructor --- src/TMLE.jl | 9 +++++---- src/scm.jl | 28 +++++++++++++++++++++------- test/scm.jl | 22 ++++++++++++++++++++-- 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index bb115e95..ca5756b3 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -24,15 +24,16 @@ import AbstractDifferentiation as AD # EXPORTS # ############################################################################# -export SE, StructuralEquation, outcome, treatments -export StructuralCausalModel, SCM, StaticConfoundedModel, equations -export fit!, reset! +export SE, StructuralEquation +export StructuralCausalModel, SCM, StaticConfoundedModel export NuisanceSpec, TMLECache, update!, CM, ATE, IATE, isidentified +export treatments, outcome, confounders, parents, fit!, reset!, \ + isidentified, equations export tmle, tmle! export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export estimands_from_yaml, estimands_to_yaml, optimize_ordering, optimize_ordering! -export TreatmentTransformer +export TreatmentTransformer # ############################################################################# # INCLUDES diff --git a/src/scm.jl b/src/scm.jl index b244ddf0..409ee6d7 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -23,8 +23,9 @@ const SE = StructuralEquation function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) if eq.mach === nothing || force === true verbosity >= 1 && @info(string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".")) - X = selectcols(dataset, parents(eq)) - y = Tables.getcolumn(dataset, outcome(eq)) + data = nomissing(dataset, vcat(parents(eq), outcome(eq))) + X = selectcols(data, parents(eq)) + y = Tables.getcolumn(data, outcome(eq)) mach = machine(eq.model, X, y, cache=cache) MLJBase.fit!(mach, verbosity=verbosity-1) eq.mach = mach @@ -125,8 +126,8 @@ end ### StaticConfoundedModel ### ##################################################################### -vcat_covariates(covariates::Nothing, treatment, confounders) = vcat(treatment, confounders) -vcat_covariates(covariates, treatment, confounders) = vcat(covariates, treatment, confounders) +vcat_covariates(treatment, confounders, covariates::Nothing) = vcat(treatment, confounders) +vcat_covariates(treatment, confounders, covariates) = vcat(treatment, confounders, covariates) function StaticConfoundedModel( outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}}; @@ -134,15 +135,28 @@ function StaticConfoundedModel( outcome_spec = LinearRegressor(), treatment_spec = LinearBinaryClassifier() ) - Yeq = StructuralEquation( + Yeq = SE( outcome, - vcat_covariates(covariates, treatment, confounders), + vcat_covariates(treatment, confounders, covariates), outcome_spec ) - Teq = StructuralEquation( + Teq = SE( treatment, vcat(confounders), treatment_spec ) return StructuralCausalModel(Yeq, Teq) +end + +function StaticConfoundedModel( + outcomes::Vector{Symbol}, + treatments::Vector{Symbol}, + confounders::Union{Symbol, AbstractVector{Symbol}}; + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_spec = LinearRegressor(), + treatment_spec = LinearBinaryClassifier() + ) + Yequations = (SE(outcome, vcat_covariates(treatments, confounders, covariates), outcome_spec) for outcome in outcomes) + Tequations = (SE(treatment, vcat(confounders), treatment_spec) for treatment in treatments) + return SCM(Yequations..., Tequations...) end \ No newline at end of file diff --git a/test/scm.jl b/test/scm.jl index fac1a4e5..7a048dba 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -41,7 +41,7 @@ end @testset "Test StaticConfoundedModel" begin # With covariates scm = StaticConfoundedModel(:Y, :T, [:W₁, :W₂], covariates=[:C₁]) - @test TMLE.parents(scm, :Y) == [:C₁, :T, :W₁, :W₂] + @test TMLE.parents(scm, :Y) == [:T, :W₁, :W₂, :C₁] @test TMLE.parents(scm, :T) == [:W₁, :W₂] @test scm.Y.model isa LinearRegressor @test scm.T.model isa LinearBinaryClassifier @@ -56,8 +56,26 @@ end # With 1 covariate scm = StaticConfoundedModel(:Y, :T, :W₁, covariates=:C₁) - @test TMLE.parents(scm, :Y) == [:C₁, :T, :W₁] + @test TMLE.parents(scm, :Y) == [:T, :W₁, :C₁] @test TMLE.parents(scm, :T) == [:W₁] + + # With multiple outcomes and treatments + scm = StaticConfoundedModel( + [:Y₁, :Y₂, :Y₃], + [:T₁, :T₂], + [:W₁, :W₂, :W₃], + covariates=[:C] + ) + + Yparents = [:T₁, :T₂, :W₁, :W₂, :W₃, :C] + @test parents(scm.Y₁) == Yparents + @test parents(scm.Y₂) == Yparents + @test parents(scm.Y₃) == Yparents + + Tparents = [:W₁, :W₂, :W₃] + @test parents(scm.T₁) == Tparents + @test parents(scm.T₂) == Tparents + end @testset "Test fit!" begin From 8e1749f9b7d4b4a9ad13871912731711f939d7a5 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 11:51:02 +0100 Subject: [PATCH 010/151] updated default Y model in scm --- src/scm.jl | 16 ++++++++-------- test/scm.jl | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/scm.jl b/src/scm.jl index 409ee6d7..f5003aa8 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -132,18 +132,18 @@ vcat_covariates(treatment, confounders, covariates) = vcat(treatment, confounder function StaticConfoundedModel( outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}}; covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_spec = LinearRegressor(), - treatment_spec = LinearBinaryClassifier() + outcome_model = TreatmentTransformer() |> LinearRegressor(), + treatment_model = LinearBinaryClassifier() ) Yeq = SE( outcome, vcat_covariates(treatment, confounders, covariates), - outcome_spec + outcome_model ) Teq = SE( treatment, vcat(confounders), - treatment_spec + treatment_model ) return StructuralCausalModel(Yeq, Teq) end @@ -153,10 +153,10 @@ function StaticConfoundedModel( treatments::Vector{Symbol}, confounders::Union{Symbol, AbstractVector{Symbol}}; covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_spec = LinearRegressor(), - treatment_spec = LinearBinaryClassifier() + outcome_model = TreatmentTransformer() |> LinearRegressor(), + treatment_model = LinearBinaryClassifier() ) - Yequations = (SE(outcome, vcat_covariates(treatments, confounders, covariates), outcome_spec) for outcome in outcomes) - Tequations = (SE(treatment, vcat(confounders), treatment_spec) for treatment in treatments) + Yequations = (SE(outcome, vcat_covariates(treatments, confounders, covariates), outcome_model) for outcome in outcomes) + Tequations = (SE(treatment, vcat(confounders), treatment_model) for treatment in treatments) return SCM(Yequations..., Tequations...) end \ No newline at end of file diff --git a/test/scm.jl b/test/scm.jl index 7a048dba..e72b6a98 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -43,12 +43,12 @@ end scm = StaticConfoundedModel(:Y, :T, [:W₁, :W₂], covariates=[:C₁]) @test TMLE.parents(scm, :Y) == [:T, :W₁, :W₂, :C₁] @test TMLE.parents(scm, :T) == [:W₁, :W₂] - @test scm.Y.model isa LinearRegressor + @test scm.Y.model isa ProbabilisticPipeline @test scm.T.model isa LinearBinaryClassifier @test scm.T.model.fit_intercept === true # Without covariates - scm = StaticConfoundedModel(:Y, :T, :W₁, outcome_spec=LinearBinaryClassifier(), treatment_spec=LinearBinaryClassifier(fit_intercept=false)) + scm = StaticConfoundedModel(:Y, :T, :W₁, outcome_model=LinearBinaryClassifier(), treatment_model=LinearBinaryClassifier(fit_intercept=false)) @test TMLE.parents(scm, :Y) == [:T, :W₁] @test TMLE.parents(scm, :T) == [:W₁] @test scm.Y.model isa LinearBinaryClassifier From 4a7fb29425f21216177c5baa847204a53c41d775 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 13:20:23 +0100 Subject: [PATCH 011/151] update utils --- src/utils.jl | 54 +++++++---- test/utils.jl | 252 ++++++++++++++++++++++++-------------------------- 2 files changed, 157 insertions(+), 149 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 8dd85a54..693999b8 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -46,32 +46,33 @@ end ncases(value, Ψ::Estimand) = sum(value[i] == Ψ.treatment[i].case for i in eachindex(value)) -function indicator_fns(Ψ::IATE, f::Function) +""" +""" +function indicator_fns(Ψ::IATE) N = length(treatments(Ψ)) key_vals = Pair[] for cf in Iterators.product((values(Ψ.treatment[T]) for T in treatments(Ψ))...) - push!(key_vals, f(cf) => float((-1)^(N - ncases(cf, Ψ)))) + push!(key_vals, cf => float((-1)^(N - ncases(cf, Ψ)))) end return Dict(key_vals...) end -indicator_fns(Ψ::CM, f::Function) = Dict(f(values(Ψ.treatment)) => 1.) +indicator_fns(Ψ::CM) = Dict(values(Ψ.treatment) => 1.) -function indicator_fns(Ψ::ATE, f::Function) +function indicator_fns(Ψ::ATE) case = [] control = [] for treatment in Ψ.treatment push!(case, treatment.case) push!(control, treatment.control) end - return Dict(f(Tuple(case)) => 1., f(Tuple(control)) => -1.) + return Dict(Tuple(case) => 1., Tuple(control) => -1.) end -function indicator_values(indicators, jointT) - indic = zeros(Float64, nrows(jointT)) - for i in eachindex(jointT) - val = jointT[i] - indic[i] = get(indicators, val, 0.) +function indicator_values(indicators, T) + indic = zeros(Float64, nrows(T)) + for (index, row) in enumerate(Tables.namedtupleiterator(T)) + indic[index] = get(indicators, values(row), 0.) end return indic end @@ -92,23 +93,38 @@ end compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = expected_value(ŷ) compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) -function balancing_weights(G::Machine, W, jointT; threshold=1e-8) - ŷ = MLJBase.predict(G, W) - d = pdf.(ŷ, jointT) - plateau!(d, threshold) - return 1. ./ d +function balancing_weights(scm, W, T; threshold=1e-8) + density = ones(nrows(T)) + for colname ∈ Tables.columnnames(T) + mach = scm[colname].mach + ŷ = MLJBase.predict(mach, W[colname]) + density .*= pdf.(ŷ, Tables.getcolumn(T, colname)) + end + plateau!(density, threshold) + return 1. ./ density end """ clever_covariate_and_weights(jointT, W, G, indicator_fns; threshold=1e-8, weighted_fluctuation=false) Computes the clever covariate and weights that are used to fluctuate the initial Q. -See [here](https://lendle.github.io/TargetedLearning.jl/user-guide/ctmle/#fluctuating-barq_n). + +if `weighted_fluctuation = false`: + +- ``clever_covariate(t, w) = \\frac{SpecialIndicator(t)}{p(t|w)}`` +- ``weight(t, w) = 1`` + +if `weighted_fluctuation = true`: + +- ``clever_covariate(t, w) = SpecialIndicator(t)`` +- ``weight(t, w) = \\frac{1}{p(t|w)}`` + +where SpecialIndicator(t) is defined in `indicator_fns`. """ -function clever_covariate_and_weights(jointT, W, G, indicator_fns; threshold=1e-8, weighted_fluctuation=false) +function clever_covariate_and_weights(scm::SCM, T, W, indicator_fns; threshold=1e-8, weighted_fluctuation=false) # Compute the indicator values - indic_vals = TMLE.indicator_values(indicator_fns, jointT) - weights = balancing_weights(G, W, jointT, threshold=threshold) + indic_vals = TMLE.indicator_values(indicator_fns, T) + weights = balancing_weights(scm, W, T, threshold=threshold) if weighted_fluctuation return indic_vals, weights end diff --git a/test/utils.jl b/test/utils.jl index 588ca54a..6a3813fd 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -12,79 +12,76 @@ using MLJGLMInterface: LinearBinaryClassifier using MLJLinearModels using MLJModels -function test_params_match(estimands, expected_params) - for (param, expected_param) in zip(estimands, expected_params) - @test typeof(param) == typeof(expected_param) - @test param.outcome == expected_param.outcome - @test param.treatment == expected_param.treatment - @test param.confounders == expected_param.confounders - @test param.covariates == expected_param.covariates - end -end - -@testset "Test joint_treatment" begin - T = ( - T₁ = categorical([0, 1, 2, 1]), - T₂ = categorical(["a", "b", "a", "b"]) +@testset "Test indicator_fns & indicator_values" begin + scm = StaticConfoundedModel( + [:Y], + [:T₁, :T₂, :T₃], + [:W] ) - jointT = TMLE.joint_treatment(T) - @test jointT == categorical(["0_&_a", "1_&_b", "2_&_a", "1_&_b"]) - - T = ( - T₁ = categorical([0, 1, 2, 1]), + dataset = ( + W = [1, 2, 3, 4, 5, 6, 7, 8], + T₁ = ["A", "B", "A", "B", "A", "B", "A", "B"], + T₂ = [0, 0, 1, 1, 0, 0, 1, 1], + T₃ = ["C", "C", "C", "C", "D", "D", "D", "D"], + Y = [1, 1, 1, 1, 1, 1, 1, 1] ) - jointT = TMLE.joint_treatment(T) - @test jointT == categorical(["0", "1", "2", "1"]) -end - - -@testset "Test indicator_fns" begin # Conditional Mean Ψ = CM( - outcome=:y, + scm=scm, + outcome=:Y, treatment=(T₁="A", T₂=1), - confounders=[:W] ) - - @test TMLE.indicator_fns(Ψ, TMLE.joint_name) == Dict{String, Float64}("A_&_1" => 1) + indicator_fns = TMLE.indicator_fns(Ψ) + @test indicator_fns == Dict(("A", 1) => 1.) + indic_values = TMLE.indicator_values(indicator_fns, treatments(dataset, Ψ)) + @test indic_values == [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0] # ATE Ψ = ATE( + scm=scm, outcome=:y, treatment=(T₁=(case="A", control="B"), T₂=(control=0, case=1)), - confounders=[:W] ) - @test TMLE.indicator_fns(Ψ, TMLE.joint_name) == Dict{String, Float64}( - "A_&_1" => 1, - "B_&_0" => -1 + indicator_fns = TMLE.indicator_fns(Ψ) + @test indicator_fns == Dict( + ("A", 1) => 1.0, + ("B", 0) => -1.0 ) + indic_values = TMLE.indicator_values(indicator_fns, treatments(dataset, Ψ)) + @test indic_values == [0.0, -1.0, 1.0, 0.0, 0.0, -1.0, 1.0, 0.0] # 2-points IATE Ψ = IATE( + scm=scm, outcome=:y, treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0)), - confounders=[:W] ) - @test TMLE.indicator_fns(Ψ, TMLE.joint_name) == Dict{String, Float64}( - "A_&_1" => 1, - "A_&_0" => -1, - "B_&_1" => -1, - "B_&_0" => 1 + indicator_fns = TMLE.indicator_fns(Ψ) + @test indicator_fns == Dict( + ("A", 1) => 1.0, + ("A", 0) => -1.0, + ("B", 1) => -1.0, + ("B", 0) => 1.0 ) + indic_values = TMLE.indicator_values(indicator_fns, treatments(dataset, Ψ)) + @test indic_values == [-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0] # 3-points IATE Ψ = IATE( + scm=scm, outcome=:y, treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0), T₃=(control="D", case="C")), - confounders=[:W] ) - @test TMLE.indicator_fns(Ψ, TMLE.joint_name) == Dict{String, Float64}( - "A_&_1_&_D" => -1, - "A_&_1_&_C" => 1, - "B_&_0_&_D" => -1, - "B_&_0_&_C" => 1, - "B_&_1_&_C" => -1, - "A_&_0_&_D" => 1, - "B_&_1_&_D" => 1, - "A_&_0_&_C" => -1 + indicator_fns = TMLE.indicator_fns(Ψ) + @test indicator_fns == Dict( + ("A", 1, "D") => -1.0, + ("A", 1, "C") => 1.0, + ("B", 0, "D") => -1.0, + ("B", 0, "C") => 1.0, + ("B", 1, "C") => -1.0, + ("A", 0, "D") => 1.0, + ("B", 1, "D") => 1.0, + ("A", 0, "C") => -1.0 ) + indic_values = TMLE.indicator_values(indicator_fns, treatments(dataset, Ψ)) + @test indic_values == [-1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0] end @testset "Test expected_value" begin @@ -116,27 +113,6 @@ end @test expectation == ŷ end -@testset "Test indicator_values" begin - indicators = Dict( - "b_&_c_&_true" => -1, - "a_&_c_&_true" => 1, - "b_&_d_&_false" => -1, - "b_&_c_&_false" => 1, - "a_&_d_&_true" => -1, - "a_&_c_&_false" => -1, - "a_&_d_&_false" => 1, - "b_&_d_&_true" => 1 - ) - jointT = categorical([ - "b_&_c_&_true", "a_&_c_&_true", "b_&_d_&_false", - "b_&_c_&_false", "a_&_d_&_true", "a_&_c_&_false", - "a_&_d_&_false", "b_&_d_&_true", "q_&_d_&_false" - ]) - # The las combination does not appear in the indicators - @test TMLE.indicator_values(indicators, jointT) == - [-1, 1, -1, 1, -1, -1, 1, 1, 0] -end - @testset "Test counterfactualTreatment" begin vals = (true, "a") @@ -151,93 +127,109 @@ end ) @test isordered(cfT.t₁) @test !isordered(cfT.t₂) - end @testset "Test clever_covariate_and_weights" begin # First case: 1 categorical variable # Using a trivial classifier # that outputs the proportions of of the classes - jointT = categorical(["a", "b", "c", "a", "a", "b", "a"]) - W = MLJBase.table(rand(7, 3)) - - Gmach = machine(ConstantClassifier(), - W, - jointT) - fit!(Gmach, verbosity=0) - + scm = StaticConfoundedModel( + :Y, :T, [:W₁, :W₂, :W₃], treatment_model = ConstantClassifier() + ) Ψ = ATE( - outcome =:y, - treatment=(t₁=(case="a", control="b"),), - confounders = [:x1, :x2, :x3] + scm=scm, + outcome =:Y, + treatment=(T=(case="a", control="b"),), + ) + dataset = ( + T = categorical(["a", "b", "c", "a", "a", "b", "a"]), + Y = rand(7), + W₁ = rand(7), + W₂ = rand(7), + W₃ = rand(7) ) - indicator_fns = TMLE.indicator_fns(Ψ, TMLE.joint_name) + fit!(scm, dataset, verbosity=0) + indicator_fns = TMLE.indicator_fns(Ψ) + T = treatments(dataset, Ψ) + @test T == (T=dataset.T,) + W = confounders(dataset, Ψ) + @test W == (T=(W₁=dataset.W₁, W₂=dataset.W₂, W₃=dataset.W₃),) - bw = TMLE.balancing_weights(Gmach, W, jointT) - cov, w = TMLE.clever_covariate_and_weights(jointT, W, Gmach, indicator_fns, weighted_fluctuation=true) + bw = TMLE.balancing_weights(scm, W, T) + cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns, weighted_fluctuation=true) @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] @test w == bw == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] - cov, w = TMLE.clever_covariate_and_weights(jointT, W, Gmach, indicator_fns) + cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns) @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] # Second case: 2 binary variables # Using a trivial classifier # that outputs the proportions of of the classes - jointT = categorical(["1_&_1", "0_&_1", "0_&_1", "1_&_1", "1_&_1", "1_&_0", "0_&_0"]) - W = MLJBase.table(rand(7, 3)) - - Gmach = machine(ConstantClassifier(), - W, - jointT) - fit!(Gmach, verbosity=0) + scm = StaticConfoundedModel( + [:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃], treatment_model = ConstantClassifier() + ) Ψ = IATE( - outcome =:y, - treatment=(t₁=(case=1, control=0), t₂=(case=1, control=0)), - confounders = [:x1, :x2, :x3] + scm=scm, + outcome =:Y, + treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), ) - indicator_fns = TMLE.indicator_fns(Ψ, TMLE.joint_name) - - cov, w = TMLE.clever_covariate_and_weights(jointT, W, Gmach, indicator_fns) - @test cov == [2.3333333333333335, - -3.5, - -3.5, - 2.3333333333333335, - 2.3333333333333335, - -7.0, - 7.0] + dataset = ( + T₁ = categorical([1, 0, 0, 1, 1, 1, 0]), + T₂ = categorical([1, 1, 1, 1, 1, 0, 0]), + Y = rand(7), + W₁ = rand(7), + W₂ = rand(7), + W₃ = rand(7) + ) + fit!(scm, dataset, verbosity=0) + indicator_fns = TMLE.indicator_fns(Ψ) + T = TMLE.treatments(dataset, Ψ) + @test T == (T₁ = dataset.T₁, T₂ = dataset.T₂) + W = TMLE.confounders(dataset, Ψ) + @test W == ( + T₁ = (W₁ = dataset.W₁, W₂ = dataset.W₂, W₃ = dataset.W₃), + T₂ = (W₁ = dataset.W₁, W₂ = dataset.W₂, W₃ = dataset.W₃) + ) + cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns) + @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 # Third case: 3 mixed categorical variables # Using a trivial classifier # that outputs the proportions of of the classes - jointT = categorical( - ["a_&_3_&_true", "a_&_2_&_false", "b_&_1_&_true", - "b_&_1_&_false", "c_&_2_&_false", "b_&_2_&_false", - "b_&_2_&_false"]) - W = MLJBase.table(rand(7, 3)) - - Gmach = machine(ConstantClassifier(), - W, - jointT) - fit!(Gmach, verbosity=0) + scm = StaticConfoundedModel( + [:Y], [:T₁, :T₂, :T₃], [:W], treatment_model = ConstantClassifier() + ) Ψ = IATE( - outcome =:y, - treatment=(t₁=(case="a", control="b"), - t₂=(case=1, control=2), - t₃=(case=true, control=false)), - confounders = [:x1, :x2, :x3] + scm=scm, + outcome =:Y, + treatment=(T₁=(case="a", control="b"), + T₂=(case=1, control=2), + T₃=(case=true, control=false)), + ) + dataset = ( + T₁ = categorical(["a", "a", "b", "b", "c", "b", "b"]), + T₂ = categorical([3, 2, 1, 1, 2, 2, 2], ordered=true), + T₃ = categorical([true, false, true, false, false, false, false], ordered=true), + Y = rand(7), + W = rand(7), + ) + + fit!(scm, dataset, verbosity=0) + indicator_fns = TMLE.indicator_fns(Ψ) + T = TMLE.treatments(dataset, Ψ) + @test T == (T₁ = dataset.T₁, T₂ = dataset.T₂, T₃ = dataset.T₃) + W = TMLE.confounders(dataset, Ψ) + @test W == ( + T₁ = (W = dataset.W,), + T₂ = (W = dataset.W,), + T₃ = (W = dataset.W,) ) - indicator_fns = TMLE.indicator_fns(Ψ, TMLE.joint_name) + indicator_fns = TMLE.indicator_fns(Ψ) - cov, w = TMLE.clever_covariate_and_weights(jointT, W, Gmach, indicator_fns) - @test cov == [0, - 7.0, - -7, - 7, - 0, - -3.5, - -3.5] + cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns) + @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 end @testset "Test compute_offset" begin From 6f090cdc5d39ddb2d4529c0f465012eb09390730 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 14:40:04 +0100 Subject: [PATCH 012/151] update utils --- src/utils.jl | 3 ++ test/utils.jl | 84 +++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 64 insertions(+), 23 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 693999b8..90e7e18c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -93,6 +93,9 @@ end compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = expected_value(ŷ) compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) +compute_offset(Ψ::CMCompositeEstimand) = + compute_offset(MLJBase.predict(Ψ.scm[outcome(Ψ)].mach)) + function balancing_weights(scm, W, T; threshold=1e-8) density = ones(nrows(T)) for colname ∈ Tables.columnnames(T) diff --git a/test/utils.jl b/test/utils.jl index 6a3813fd..1ea1aecf 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -117,24 +117,25 @@ end @testset "Test counterfactualTreatment" begin vals = (true, "a") T = ( - t₁ = categorical([true, false, false], ordered=true), - t₂ = categorical(["a", "a", "c"]) + T₁ = categorical([true, false, false], ordered=true), + T₂ = categorical(["a", "a", "c"]) ) cfT = TMLE.counterfactualTreatment(vals, T) @test cfT == ( - t₁ = categorical([true, true, true]), - t₂ = categorical(["a", "a", "a"]) + T₁ = categorical([true, true, true]), + T₂ = categorical(["a", "a", "a"]) ) - @test isordered(cfT.t₁) - @test !isordered(cfT.t₂) + @test isordered(cfT.T₁) + @test !isordered(cfT.T₂) end -@testset "Test clever_covariate_and_weights" begin - # First case: 1 categorical variable - # Using a trivial classifier - # that outputs the proportions of of the classes +@testset "Test compute_offset, clever_covariate_and_weights and tmle_step" begin + ### In all cases, models are "constant" models + ## First case: 1 Treamtent variable scm = StaticConfoundedModel( - :Y, :T, [:W₁, :W₂, :W₃], treatment_model = ConstantClassifier() + :Y, :T, [:W₁, :W₂, :W₃], + treatment_model = ConstantClassifier(), + outcome_model = ConstantRegressor() ) Ψ = ATE( scm=scm, @@ -143,7 +144,7 @@ end ) dataset = ( T = categorical(["a", "b", "c", "a", "a", "b", "a"]), - Y = rand(7), + Y = [1., 2., 3, 4, 5, 6, 7], W₁ = rand(7), W₂ = rand(7), W₃ = rand(7) @@ -154,21 +155,37 @@ end @test T == (T=dataset.T,) W = confounders(dataset, Ψ) @test W == (T=(W₁=dataset.W₁, W₂=dataset.W₂, W₃=dataset.W₃),) + + @test TMLE.compute_offset(Ψ) == repeat([mean(dataset.Y)], 7) bw = TMLE.balancing_weights(scm, W, T) cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns, weighted_fluctuation=true) @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] @test w == bw == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] + + weighted_mach = TMLE.tmle_step(Ψ, weighted_fluctuation=true) + @test fitted_params(weighted_mach) == ( + features = [:covariate], + coef = [0.12500000000000003], + intercept = 0.0 + ) cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns) @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] - # Second case: 2 binary variables - # Using a trivial classifier - # that outputs the proportions of of the classes + unweighted_mach = TMLE.tmle_step(Ψ, weighted_fluctuation=false) + @test fitted_params(unweighted_mach) == ( + features = [:covariate], + coef = [0.047619047619047616], + intercept = 0.0 + ) + + ## Second case: 2 Treatment variables scm = StaticConfoundedModel( - [:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃], treatment_model = ConstantClassifier() + [:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃], + treatment_model = ConstantClassifier(), + outcome_model = ConstantClassifier() ) Ψ = IATE( scm=scm, @@ -178,7 +195,7 @@ end dataset = ( T₁ = categorical([1, 0, 0, 1, 1, 1, 0]), T₂ = categorical([1, 1, 1, 1, 1, 0, 0]), - Y = rand(7), + Y = categorical([1, 1, 1, 1, 0, 0, 0]), W₁ = rand(7), W₂ = rand(7), W₃ = rand(7) @@ -192,14 +209,25 @@ end T₁ = (W₁ = dataset.W₁, W₂ = dataset.W₂, W₃ = dataset.W₃), T₂ = (W₁ = dataset.W₁, W₂ = dataset.W₂, W₃ = dataset.W₃) ) + # Because the outcome is binary, the offset is the logit + @test TMLE.compute_offset(Ψ) == repeat([0.28768207245178085], 7) + cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns) @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 + @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - # Third case: 3 mixed categorical variables - # Using a trivial classifier - # that outputs the proportions of of the classes + mach = TMLE.tmle_step(Ψ) + @test fitted_params(mach) == ( + features = [:covariate], + coef = [-0.0945122832139119], + intercept = 0.0 + ) + + ## Third case: 3 Treatment variables scm = StaticConfoundedModel( - [:Y], [:T₁, :T₂, :T₃], [:W], treatment_model = ConstantClassifier() + [:Y], [:T₁, :T₂, :T₃], [:W], + treatment_model = ConstantClassifier(), + outcome_model = DeterministicConstantRegressor() ) Ψ = IATE( scm=scm, @@ -212,8 +240,8 @@ end T₁ = categorical(["a", "a", "b", "b", "c", "b", "b"]), T₂ = categorical([3, 2, 1, 1, 2, 2, 2], ordered=true), T₃ = categorical([true, false, true, false, false, false, false], ordered=true), - Y = rand(7), - W = rand(7), + Y = [1., 2., 3, 4, 5, 6, 7], + W = rand(7), ) fit!(scm, dataset, verbosity=0) @@ -228,8 +256,18 @@ end ) indicator_fns = TMLE.indicator_fns(Ψ) + @test TMLE.compute_offset(Ψ) == repeat([4.0], 7) + cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns) @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 + @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + + mach = TMLE.tmle_step(Ψ) + @test fitted_params(mach) == ( + features = [:covariate], + coef = [-0.02665556018325698], + intercept = 0.0 + ) end @testset "Test compute_offset" begin From d0ba05c1192179c675b9eaf861bd77b4e641b6d0 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 15:54:26 +0100 Subject: [PATCH 013/151] update cache --- src/TMLE.jl | 5 +- src/cache.jl | 314 ++++++------------------------------ src/estimands.jl | 100 +++--------- src/utils.jl | 4 +- test/estimands.jl | 4 +- test/non_regression_test.jl | 28 ++-- test/utils.jl | 50 ++---- 7 files changed, 105 insertions(+), 400 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index ca5756b3..0d465aed 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -26,9 +26,8 @@ import AbstractDifferentiation as AD export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel -export NuisanceSpec, TMLECache, update!, CM, ATE, IATE, isidentified -export treatments, outcome, confounders, parents, fit!, reset!, \ - isidentified, equations +export CM, ATE, IATE +export parents, fit!, reset!, isidentified, equations export tmle, tmle! export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose diff --git a/src/cache.jl b/src/cache.jl index fdc7b9d5..5222aefe 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -1,27 +1,3 @@ -""" -This structure is used as a cache for: - - data: data changes - - η: The nuisance estimands estimators - - Ψ: The estimand of interest - - η_spec: The specification of the learning algorithms used to estimate the nuisance estimands - -In many cases, if some of those fields do not change, the nuisance estimands do not have to be estimated again hence -saving computations. -""" -mutable struct TMLECache - data - η::NuisanceEstimands - Ψ::Estimand - η_spec::NuisanceSpec - function TMLECache(dataset) - data = Dict{Symbol, Any}(:source => dataset) - η = NuisanceEstimands(nothing, nothing, nothing, nothing) - new(data, η) - end -end - -Base.show(io::IO, cache::TMLECache) = println(typeof(cache)) - function check_treatment_settings(settings::NamedTuple, levels, treatment_name) for (key, val) in zip(keys(settings), settings) @@ -39,7 +15,7 @@ function check_treatment_settings(setting, levels, treatment_name) " in Ψ does not match any level of the corresponding variable in the dataset: ", string.(levels)))) end -function check_treatment_values(cache::TMLECache, Ψ::Estimand) +function check_treatment_values(cache, Ψ::Estimand) for treatment_name in treatments(Ψ) treatment_levels = string.(levels(Tables.getcolumn(cache.data[:source], treatment_name))) treatment_settings = getproperty(Ψ.treatment, treatment_name) @@ -47,261 +23,66 @@ function check_treatment_values(cache::TMLECache, Ψ::Estimand) end end - -function update!(cache::TMLECache, Ψ::Estimand) - check_treatment_values(cache, Ψ) - any_variable_changed = false - if !isdefined(cache, :Ψ) - any_variable_changed = true - else - if keys(cache.Ψ.treatment) != keys(Ψ.treatment) - cache.η.G = nothing - cache.η.Q = nothing - cache.η.H = nothing - any_variable_changed = true - end - if cache.Ψ.confounders != Ψ.confounders - cache.η.G = nothing - cache.η.Q = nothing - any_variable_changed = true - end - if cache.Ψ.covariates != Ψ.covariates - cache.η.Q = nothing - any_variable_changed = true - end - if cache.Ψ.outcome != Ψ.outcome - cache.η.Q = nothing - any_variable_changed = true - end - cache.η.F = nothing - end - # Update no missing dataset - if any_variable_changed - cache.data[:no_missing] = nomissing(cache.data[:source], allcolumns(Ψ)) - end - # Update indicator functions - cache.data[:indicators_str] = indicator_fns(Ψ, joint_name) - cache.data[:indicators_tuple] = indicator_fns(Ψ, x -> x) - # Update Ψ - cache.Ψ = Ψ - - return cache -end - -function update!(cache::TMLECache, η_spec::NuisanceSpec) - if isdefined(cache, :η_spec) - if cache.η_spec.G != η_spec.G - cache.η.G = nothing - end - if cache.η_spec.Q != η_spec.Q - cache.η.Q = nothing - end - cache.η.F = nothing - end - cache.η_spec = η_spec - return cache -end - -function update!(cache::TMLECache, Ψ::Estimand, η_spec::NuisanceSpec) - update!(cache, Ψ) - update!(cache, η_spec) -end - -const TMLE_ARGS_DOCS = """ -# Arguments - -- Ψ: The estimand of interest -- η_spec: The specification for learning `Q_0` and `G_0` -- dataset: A tabular dataset respecting the Table.jl interface -- verbosity: The logging level -- threshold: To avoid small values of Ĝ to cause the "clever covariate" to explode -- weighted_fluctuation: Fits a weighted fluctuation -""" - -""" - tmle(Ψ::Estimand, η_spec::NuisanceSpec, dataset; kwargs...) - -Build a TMLECache struct and runs TMLE. - -$TMLE_ARGS_DOCS -""" -function tmle(Ψ::Estimand, η_spec::NuisanceSpec, dataset; kwargs...) - cache = TMLECache(dataset) - update!(cache, Ψ, η_spec) - return tmle!(cache; kwargs...) -end - -""" - tmle!(cache::TMLECache; verbosity=1, threshold=1e-8, weighted_fluctuation=false) - -Runs TMLE using the provided TMLECache. - -$TMLE_ARGS_DOCS -""" -function tmle!(cache::TMLECache; verbosity=1, threshold=1e-8, weighted_fluctuation=false) - # Initial fit of the nuisance estimands - verbosity >= 1 && @info "Fitting the nuisance estimands..." - TMLE.fit_nuisance!(cache, verbosity=verbosity) - +function tmle(Ψ::Estimand, dataset; verbosity=1, cache=true, force=false, threshold=1e-8, weighted_fluctuation=false) + # Initial fit of the SCM's equations + verbosity >= 1 && @info "Fitting the required equations..." + fit!(Ψ, dataset; verbosity=verbosity, cache=cache, force=force) # TMLE step - verbosity >= 1 && @info "Targeting the nuisance estimands..." - tmle_step!(cache, verbosity=verbosity, threshold=threshold, weighted_fluctuation=weighted_fluctuation) - + verbosity >= 1 && @info "Performing TMLE..." + fluctuation_mach = tmle_step(Ψ, verbosity=verbosity, threshold=threshold, weighted_fluctuation=weighted_fluctuation) # Estimation results after TMLE - IC, Ψ̂, ICᵢ, Ψ̂ᵢ = TMLE.gradient_and_estimates(cache, threshold=threshold, weighted_fluctuation=weighted_fluctuation) + IC, Ψ̂, ICᵢ, Ψ̂ᵢ = gradient_and_estimates(Ψ, fluctuation_mach, threshold=threshold, weighted_fluctuation=weighted_fluctuation) tmle_result = ALEstimate(Ψ̂, IC) one_step_result = ALEstimate(Ψ̂ᵢ + mean(ICᵢ), ICᵢ) verbosity >= 1 && @info "Done." - return TMLEResult(cache.Ψ, tmle_result, one_step_result, Ψ̂ᵢ), cache -end - -""" - tmle!(cache::TMLECache, Ψ::Estimand; verbosity=1, threshold=1e-8) - -Updates the TMLECache with Ψ and runs TMLE. - -$TMLE_ARGS_DOCS -""" -function tmle!(cache::TMLECache, Ψ::Estimand; kwargs...) - update!(cache, Ψ) - tmle!(cache; kwargs...) -end - -""" - tmle!(cache::TMLECache, η_spec::NuisanceSpec; kwargs...) - -Updates the TMLECache with η_spec and runs TMLE. - -$TMLE_ARGS_DOCS -""" -function tmle!(cache::TMLECache, η_spec::NuisanceSpec; kwargs...) - update!(cache, η_spec) - tmle!(cache; kwargs...) -end - -""" - tmle!(cache::TMLECache, Ψ::Estimand, η_spec::NuisanceSpec; kwargs...) - -Updates the TMLECache with η_spec and Ψ and runs TMLE. - -$TMLE_ARGS_DOCS -""" -function tmle!(cache::TMLECache, Ψ::Estimand, η_spec::NuisanceSpec; kwargs...) - update!(cache, Ψ, η_spec) - tmle!(cache; kwargs...) -end - -""" - tmle!(cache::TMLECache, η_spec::NuisanceSpec, Ψ::Estimand; kwargs...) - -Updates the TMLECache with η_spec and Ψ and runs TMLE. - -$TMLE_ARGS_DOCS -""" -function tmle!(cache::TMLECache, η_spec::NuisanceSpec, Ψ::Estimand; kwargs...) - update!(cache, Ψ, η_spec) - tmle!(cache; kwargs...) -end - - -""" - fit_nuisance!(cache::TMLECache; verbosity=1, mach_cache=false) - -Fits the nuisance estimands η on the dataset using the specifications from η_spec -and the variables defined by Ψ. -""" -function fit_nuisance!(cache::TMLECache; verbosity=1) - Ψ, η_spec, η = cache.Ψ, cache.η_spec, cache.η - # Fitting P(T|W) - # Only rows with missing values in either W or Tₜ are removed - if η.G === nothing - log_fit(verbosity, "P(T|W)") - nomissing_WT = nomissing(cache.data[:source], treatment_and_confounders(Ψ)) - W = confounders(nomissing_WT, Ψ) - jointT = joint_treatment(treatments(nomissing_WT, Ψ)) - mach = machine(η_spec.G, W, jointT, cache=η_spec.cache) - t = time() - MLJBase.fit!(mach, verbosity=verbosity-1) - verbosity >= 2 && @info string("Time to fit P(T|W): ", time() - t, " s.") - η.G = mach - cache.data[:jointT_levels] = levels(jointT) - else - log_no_fit(verbosity, "P(T|W)") - end - - if η.Q === nothing - # Fitting the Treatment Encoder - if η.H === nothing - log_fit(verbosity, "Encoder") - mach = machine(η_spec.H, treatments(cache.data[:source], Ψ), cache=η_spec.cache) - MLJBase.fit!(mach, verbosity=verbosity-1) - η.H = mach - else - log_no_fit(verbosity, "Encoder") - end - # Data - X = Qinputs(η.H, cache.data[:no_missing], Ψ) - y = outcome(cache.data[:no_missing], Ψ) - # Fitting E[Y|X] - log_fit(verbosity, "E[Y|X]") - mach = machine(η_spec.Q, X, y, cache=η_spec.cache) - t = time() - MLJBase.fit!(mach, verbosity=verbosity-1) - verbosity >= 2 && @info string("Time to fit E[Y|X]: ", time() - t, " s.") - η.Q = mach - cache.data[:Q₀] = MLJBase.predict(mach, X) - else - log_no_fit(verbosity, "Encoder") - log_no_fit(verbosity, "E[Y|X]") - end -end - - -function tmle_step!(cache::TMLECache; verbosity=1, threshold=1e-8, weighted_fluctuation=false) - # Fit fluctuation - offset = TMLE.compute_offset(cache.data[:Q₀]) - W = TMLE.confounders(cache.data[:no_missing], cache.Ψ) - jointT = TMLE.joint_treatment(TMLE.treatments(cache.data[:no_missing], cache.Ψ)) + TMLEResult(Ψ, tmle_result, one_step_result, Ψ̂ᵢ), fluctuation_mach +end + +function tmle_step(Ψ::CMCompositeEstimand; verbosity=1, threshold=1e-8, weighted_fluctuation=false) + outcome_equation = Ψ.scm[outcome(Ψ)] + X, y = outcome_equation.mach.data + outcome_mach = outcome_equation.mach + # Compute offset + offset = TMLE.compute_offset(MLJBase.predict(outcome_mach)) + # Compute clever covariate and weights for weighted fluctuation mode + W = confounders(X, Ψ) + T = treatments(X, Ψ) covariate, weights = TMLE.clever_covariate_and_weights( - jointT, W, cache.η.G, cache.data[:indicators_str]; + Ψ.scm, T, W, indicator_fns(Ψ); threshold=threshold, weighted_fluctuation=weighted_fluctuation ) - X = TMLE.fluctuation_input(covariate, offset) - y = TMLE.outcome(cache.data[:no_missing], cache.Ψ) - mach = machine(cache.η_spec.F, X, y, weights, cache=cache.η_spec.cache) + # Fit fluctuation + Xfluct = TMLE.fluctuation_input(covariate, offset) + mach = machine(F_model(scitype(y)), Xfluct, y, weights, cache=true) MLJBase.fit!(mach, verbosity=verbosity-1) - # Update cache - cache.η.F = mach - # This is useful for gradient_Y_X - cache.data[:covariate] = X.covariate .* weights - cache.data[:Qfluct] = MLJBase.predict(mach, X) + return mach end -function counterfactual_aggregates(cache::TMLECache; threshold=1e-8, weighted_fluctuation=false) - dataset = cache.data[:no_missing] - WC = TMLE.confounders_and_covariates(dataset, cache.Ψ) - Ttemplate = TMLE.treatments(dataset, cache.Ψ) +function counterfactual_aggregates(Ψ::CMCompositeEstimand, fluctuation_mach; threshold=1e-8, weighted_fluctuation=false) + outcome_eq = outcome_equation(Ψ) + X = outcome_eq.mach.data[1] + Ttemplate = treatments(X, Ψ) n = nrows(Ttemplate) counterfactual_aggregateᵢ = zeros(n) counterfactual_aggregate = zeros(n) # Loop over Treatment settings - for (vals, sign) in cache.data[:indicators_tuple] + indicators = indicator_fns(Ψ) + for (vals, sign) in indicators # Counterfactual dataset for a given treatment setting - Tc = TMLE.counterfactualTreatment(vals, Ttemplate) - Xc = Qinputs(cache.η.H, merge(WC, Tc), cache.Ψ) + T_ct = TMLE.counterfactualTreatment(vals, Ttemplate) + X_ct = selectcols(merge(X, T_ct), parents(outcome_eq)) # Counterfactual predictions with the initial Q - ŷᵢ = MLJBase.predict(cache.η.Q, Xc) + ŷᵢ = MLJBase.predict(outcome_eq.mach, X_ct) counterfactual_aggregateᵢ .+= sign .* expected_value(ŷᵢ) # Counterfactual predictions with F offset = compute_offset(ŷᵢ) - jointT = categorical(repeat([joint_name(vals)], n), levels=cache.data[:jointT_levels]) covariate, _ = clever_covariate_and_weights( - jointT, confounders(WC, cache.Ψ), cache.η.G, cache.data[:indicators_str]; + Ψ.scm, T_ct, confounders(X, Ψ), indicators; threshold=threshold, weighted_fluctuation=weighted_fluctuation ) - Xfluct = fluctuation_input(covariate, offset) - ŷ = predict(cache.η.F, Xfluct) + Xfluct_ct = fluctuation_input(covariate, offset) + ŷ = predict(fluctuation_mach, Xfluct_ct) counterfactual_aggregate .+= sign .* expected_value(ŷ) end return counterfactual_aggregate, counterfactual_aggregateᵢ @@ -323,19 +104,24 @@ gradient_W(counterfactual_aggregate, estimate) = This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. """ -function gradients_Y_X(cache::TMLECache) - covariate = cache.data[:covariate] - y = float(outcome(cache.data[:no_missing], cache.Ψ)) - gradient_Y_Xᵢ = covariate .* (y .- expected_value(cache.data[:Q₀])) - gradient_Y_X_fluct = covariate .* (y .- expected_value(cache.data[:Qfluct])) +function gradients_Y_X(outcome_mach::Machine, fluctuation_mach::Machine) + X, y = fluctuation_mach.data + y = float(y) + gradient_Y_Xᵢ = X.covariate .* (y .- expected_value(MLJBase.predict(outcome_mach))) + gradient_Y_X_fluct = X.covariate .* (y .- expected_value(MLJBase.predict(fluctuation_mach))) return gradient_Y_Xᵢ, gradient_Y_X_fluct end -function gradient_and_estimates(cache::TMLECache; threshold=1e-8, weighted_fluctuation=false) - counterfactual_aggregate, counterfactual_aggregateᵢ = TMLE.counterfactual_aggregates(cache; threshold=threshold, weighted_fluctuation=weighted_fluctuation) +function gradient_and_estimates(Ψ::CMCompositeEstimand, fluctuation_mach::Machine; threshold=1e-8, weighted_fluctuation=false) + counterfactual_aggregate, counterfactual_aggregateᵢ = TMLE.counterfactual_aggregates( + Ψ, + fluctuation_mach; + threshold=threshold, + weighted_fluctuation=weighted_fluctuation + ) Ψ̂, Ψ̂ᵢ = mean(counterfactual_aggregate), mean(counterfactual_aggregateᵢ) - gradient_Y_Xᵢ, gradient_Y_X_fluct = gradients_Y_X(cache) + gradient_Y_Xᵢ, gradient_Y_X_fluct = gradients_Y_X(outcome_equation(Ψ).mach, fluctuation_mach) IC = gradient_Y_X_fluct .+ gradient_W(counterfactual_aggregate, Ψ̂) ICᵢ = gradient_Y_Xᵢ .+ gradient_W(counterfactual_aggregateᵢ, Ψ̂ᵢ) return IC, Ψ̂, ICᵢ, Ψ̂ᵢ diff --git a/src/estimands.jl b/src/estimands.jl index 5c431010..7e0c479d 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -149,59 +149,6 @@ IATE₁ = IATE( end -##################################################################### -### Nuisance Estimands ### -##################################################################### - -""" -# NuisanceEstimands - -The set of estimators that need to be estimated but are not of direct interest. - -# Causal graph: - -$causal_graph - -# Fields: - -All fields are MLJBase.Machine. - - - Q: An estimator of E[Y|X] - - G: An estimator of P(T|W) - - H: A one-hot-encoder categorical treatments - - F: A generalized linear model to fluctuate E[Y|X] -""" -mutable struct NuisanceEstimands - Q::Union{Nothing, MLJBase.Machine} - G::Union{Nothing, MLJBase.Machine} - H::Union{Nothing, MLJBase.Machine} - F::Union{Nothing, MLJBase.Machine} -end - -struct NuisanceSpec - Q::MLJBase.Model - G::MLJBase.Model - H::MLJBase.Model - F::MLJBase.Model - cache::Bool -end - -""" - NuisanceSpec(Q, G; H=encoder(), F=Q_model(target_scitype(Q))) - -Specification of the nuisance estimands to be learnt. - -# Arguments: - -- Q: For the estimation of E₀[Y|T=case, X] -- G: For the estimation of P₀(T|W) -- H: The `TreatmentTransformer`` to deal with categorical treatments -- F: The generalized linear model used to fluctuate the initial Q -- cache: Whether corresponding machines will cache data or not. -""" -NuisanceSpec(Q, G; H=TreatmentTransformer(), F=F_model(target_scitype(Q)), cache=true) = - NuisanceSpec(Q, G, H, F, cache) - ##################################################################### ### Methods ### ##################################################################### @@ -210,7 +157,9 @@ CMCompositeEstimand = Union{CM, ATE, IATE} Base.show(io::IO, Ψ::T) where T <: CMCompositeEstimand = println(io, T) -equations_to_fit(Ψ::CMCompositeEstimand) = (Ψ.scm[outcome(Ψ)], (Ψ.scm[t] for t in treatments(Ψ))...) +equations_to_fit(Ψ::CMCompositeEstimand) = (outcome_equation(Ψ), (Ψ.scm[t] for t in treatments(Ψ))...) + +outcome_equation(Ψ::CMCompositeEstimand) = Ψ.scm[outcome(Ψ)] variable_not_indataset(role, variable) = string(role, " variable: ", variable, " is not in the dataset.") @@ -242,46 +191,35 @@ function isidentified(Ψ::CMCompositeEstimand, dataset) return length(reasons) == 0, reasons end -outcome(Ψ::CMCompositeEstimand) = Ψ.outcome - -selectcols(data, cols) = data |> TableOperations.select(cols...) |> Tables.columntable +confounders(scm, variable, outcome) = + intersect(parents(scm, outcome), parents(scm, variable)) function confounders(Ψ::CMCompositeEstimand) + confounders_ = [] + treatments_ = Tuple(treatments(Ψ)) + for treatment in treatments_ + push!(confounders_, confounders(Ψ.scm, treatment, outcome(Ψ))) + end + return NamedTuple{treatments_}(confounders_) +end + +function confounders(dataset, Ψ::CMCompositeEstimand) confounders_ = [] treatments_ = Tuple(treatments(Ψ)) for treatment in treatments_ push!( confounders_, - intersect(parents(Ψ.scm, outcome(Ψ)), parents(Ψ.scm, treatment)) + selectcols(dataset, confounders(Ψ.scm, treatment, outcome(Ψ))) ) end return NamedTuple{treatments_}(confounders_) end -confounders(dataset, Ψ) = selectcols(dataset, confounders(Ψ)) - - treatments(Ψ::CMCompositeEstimand) = collect(keys(Ψ.treatment)) -treatments(dataset, Ψ) = selectcols(dataset, treatments(Ψ)) - -outcome(Ψ::Estimand) = Ψ.outcome -outcome(dataset, Ψ) = Tables.getcolumn(dataset, outcome(Ψ)) +treatments(dataset, Ψ::CMCompositeEstimand) = selectcols(dataset, treatments(Ψ)) -treatment_and_confounders(Ψ::Estimand) = vcat(confounders(Ψ), treatments(Ψ)) - -confounders_and_covariates(Ψ::Estimand) = vcat(confounders(Ψ), covariates(Ψ)) -confounders_and_covariates(dataset, Ψ) = selectcols(dataset, confounders_and_covariates(Ψ)) - -""" -Merges together confounders, covariates and floating point representation of treatments. -""" -Qinputs(H, dataset, Ψ::Estimand) = merge( - columntable(confounders_and_covariates(dataset, Ψ)), - MLJBase.transform(H, treatments(dataset, Ψ)) - ) - - -allcolumns(Ψ::Estimand) = vcat(confounders_and_covariates(Ψ), treatments(Ψ), outcome(Ψ)) +outcome(Ψ::CMCompositeEstimand) = Ψ.outcome +outcome(dataset, Ψ::CMCompositeEstimand) = Tables.getcolumn(dataset, outcome(Ψ)) F_model(::Type{<:AbstractVector{<:MLJBase.Continuous}}) = LinearRegressor(fit_intercept=false, offsetcol = :offset) @@ -296,7 +234,7 @@ namedtuples_from_dicts(d::Dict) = NamedTuple{Tuple(keys(d))}([namedtuples_from_dicts(val) for val in values(d)]) -function param_key(Ψ::Estimand) +function param_key(Ψ::CMCompositeEstimand) return ( join(Ψ.confounders, "_"), join(keys(Ψ.treatment), "_"), diff --git a/src/utils.jl b/src/utils.jl index 90e7e18c..830d1559 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,8 @@ ## General Utilities ############################################################################### +selectcols(data, cols) = data |> TableOperations.select(cols...) |> Tables.columntable + function logit!(v) for i in eachindex(v) v[i] = logit(v[i]) @@ -94,7 +96,7 @@ compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = exp compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) compute_offset(Ψ::CMCompositeEstimand) = - compute_offset(MLJBase.predict(Ψ.scm[outcome(Ψ)].mach)) + compute_offset(MLJBase.predict(outcome_equation(Ψ).mach)) function balancing_weights(scm, W, T; threshold=1e-8) density = ones(nrows(T)) diff --git a/test/estimands.jl b/test/estimands.jl index 478fc377..76077064 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -8,7 +8,7 @@ using StableRNGs using MLJBase using MLJGLMInterface -@testset "Test methods" begin +@testset "Test various methods" begin n = 5 dataset = ( Y = rand(n), @@ -28,8 +28,10 @@ using MLJGLMInterface ) @test TMLE.treatments(Ψ) == [:T] + @test TMLE.treatments(dataset, Ψ) == (T=dataset.T,) @test TMLE.outcome(Ψ) == :Y @test TMLE.confounders(Ψ) == (T=[:W₁, :W₂],) + @test TMLE.confounders(dataset, Ψ) == (T=(W₁=dataset.W₁, W₂=dataset.W₂),) identified, reasons = isidentified(Ψ, dataset) @test identified === true @test reasons == [] diff --git a/test/non_regression_test.jl b/test/non_regression_test.jl index 9e2dacb0..5f28f130 100644 --- a/test/non_regression_test.jl +++ b/test/non_regression_test.jl @@ -9,25 +9,23 @@ using MLJGLMInterface @testset "Test ATE on perinatal dataset." begin # This is a non-regression test which was checked against the R tmle3 package - data = CSV.read(joinpath("data", "perinatal.csv"), DataFrame, missingstring=["", "NA"]) + dataset = CSV.read(joinpath("data", "perinatal.csv"), DataFrame, missingstring=["", "NA"]) confounders = [:apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn] - data.haz01 = categorical(data.haz01) - data.parity01 = categorical(data.parity01) + dataset.haz01 = categorical(dataset.haz01) + dataset.parity01 = categorical(dataset.parity01) for col in confounders - data[!, col] = float(data[!, col]) + dataset[!, col] = float(dataset[!, col]) end - Ψ = ATE( - outcome=:haz01, - treatment = (parity01=(case=1, control=0),), - confounders = confounders + scm = StaticConfoundedModel( + :haz01, :parity01, confounders, + outcome_model = TreatmentTransformer() |> LinearBinaryClassifier(), + treatment_model = LinearBinaryClassifier() ) - η_spec = NuisanceSpec( - LinearBinaryClassifier(), - LinearBinaryClassifier() - ) - r, cache = tmle(Ψ, η_spec, data, threshold=0.025, verbosity=0) - l, u = confint(OneSampleTTest(r.tmle)) - @test TMLE.estimate(r.tmle) ≈ -0.185533 atol = 1e-6 + Ψ = ATE(scm=scm, outcome=:haz01, treatment=(parity01=(case=1, control=0),)) + + result, fluctuation_mach = tmle(Ψ, dataset, threshold=0.025, verbosity=0) + l, u = confint(OneSampleTTest(result.tmle)) + @test TMLE.estimate(result.tmle) ≈ -0.185533 atol = 1e-6 @test l ≈ -0.279246 atol = 1e-6 @test u ≈ -0.091821 atol = 1e-6 end diff --git a/test/utils.jl b/test/utils.jl index 1ea1aecf..543c8ddf 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -33,7 +33,7 @@ using MLJModels ) indicator_fns = TMLE.indicator_fns(Ψ) @test indicator_fns == Dict(("A", 1) => 1.) - indic_values = TMLE.indicator_values(indicator_fns, treatments(dataset, Ψ)) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) @test indic_values == [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0] # ATE Ψ = ATE( @@ -46,7 +46,7 @@ using MLJModels ("A", 1) => 1.0, ("B", 0) => -1.0 ) - indic_values = TMLE.indicator_values(indicator_fns, treatments(dataset, Ψ)) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) @test indic_values == [0.0, -1.0, 1.0, 0.0, 0.0, -1.0, 1.0, 0.0] # 2-points IATE Ψ = IATE( @@ -61,7 +61,7 @@ using MLJModels ("B", 1) => -1.0, ("B", 0) => 1.0 ) - indic_values = TMLE.indicator_values(indicator_fns, treatments(dataset, Ψ)) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) @test indic_values == [-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0] # 3-points IATE Ψ = IATE( @@ -80,7 +80,7 @@ using MLJModels ("B", 1, "D") => 1.0, ("A", 0, "C") => -1.0 ) - indic_values = TMLE.indicator_values(indicator_fns, treatments(dataset, Ψ)) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) @test indic_values == [-1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0] end @@ -113,7 +113,6 @@ end @test expectation == ŷ end - @testset "Test counterfactualTreatment" begin vals = (true, "a") T = ( @@ -129,7 +128,7 @@ end @test !isordered(cfT.T₂) end -@testset "Test compute_offset, clever_covariate_and_weights and tmle_step" begin +@testset "Test Targeting sequence with 1 Treatment" begin ### In all cases, models are "constant" models ## First case: 1 Treamtent variable scm = StaticConfoundedModel( @@ -151,21 +150,21 @@ end ) fit!(scm, dataset, verbosity=0) indicator_fns = TMLE.indicator_fns(Ψ) - T = treatments(dataset, Ψ) + T = TMLE.treatments(dataset, Ψ) @test T == (T=dataset.T,) - W = confounders(dataset, Ψ) + W = TMLE.confounders(dataset, Ψ) @test W == (T=(W₁=dataset.W₁, W₂=dataset.W₂, W₃=dataset.W₃),) @test TMLE.compute_offset(Ψ) == repeat([mean(dataset.Y)], 7) - + weighted_fluctuation = true bw = TMLE.balancing_weights(scm, W, T) - cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns, weighted_fluctuation=true) + cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns, weighted_fluctuation=weighted_fluctuation) @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] @test w == bw == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] - weighted_mach = TMLE.tmle_step(Ψ, weighted_fluctuation=true) - @test fitted_params(weighted_mach) == ( + mach = TMLE.tmle_step(Ψ, weighted_fluctuation=weighted_fluctuation) + @test fitted_params(mach) == ( features = [:covariate], coef = [0.12500000000000003], intercept = 0.0 @@ -180,8 +179,9 @@ end coef = [0.047619047619047616], intercept = 0.0 ) +end - ## Second case: 2 Treatment variables +@testset "Test Targeting sequence with 2 Treatments" begin scm = StaticConfoundedModel( [:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃], treatment_model = ConstantClassifier(), @@ -222,7 +222,9 @@ end coef = [-0.0945122832139119], intercept = 0.0 ) +end +@testset "Test Targeting sequence with 3 Treatments" begin ## Third case: 3 Treatment variables scm = StaticConfoundedModel( [:Y], [:T₁, :T₂, :T₃], [:W], @@ -270,28 +272,6 @@ end ) end -@testset "Test compute_offset" begin - n = 10 - X = rand(n, 3) - - # When Y is binary - y = categorical([1, 1, 1, 1, 0, 0, 0, 0, 0, 0]) - mach = machine(ConstantClassifier(), MLJBase.table(X), y) - fit!(mach, verbosity=0) - ŷ = MLJBase.predict(mach) - # Should be equal to logit(Ê[Y|X])= logit(4/10) = -0.4054651081081643 - @test TMLE.compute_offset(ŷ) == repeat([-0.4054651081081643], n) - - # When Y is continuous - y = [1., 2., 3, 4, 5, 6, 7, 8, 9, 10] - mach = machine(MLJModels.DeterministicConstantRegressor(), MLJBase.table(X), y) - fit!(mach, verbosity=0) - ŷ = predict(mach) - # Should be equal to Ê[Y|X] = 5.5 - @test TMLE.compute_offset(ŷ) == repeat([5.5], n) - -end - @testset "Test fluctuation_input" begin X = TMLE.fluctuation_input([1., 2.], [1., 2]) @test X.covariate isa Vector{Float64} From 37e1486e046f08bb935f74a3e98aa761c505e374 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 16:15:15 +0100 Subject: [PATCH 014/151] more updates --- src/utils.jl | 14 +----------- test/3points_interactions.jl | 44 ++++++++++++++++++++---------------- test/helper_fns.jl | 15 ++++++------ 3 files changed, 33 insertions(+), 40 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 830d1559..6a8748e9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -16,18 +16,6 @@ function plateau!(v::AbstractVector, threshold) end end -joint_name(it) = join(it, "_&_") - -joint_treatment(T) = - categorical(joint_name.(Tables.rows(T))) - - -log_fit(verbosity, model) = - verbosity >= 1 && @info string("→ Fitting ", model) - -log_no_fit(verbosity, model) = - verbosity >= 1 && @info string("→ Reusing previous ", model) - function nomissing(table) sch = Tables.schema(table) for type in sch.types @@ -98,7 +86,7 @@ compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) compute_offset(Ψ::CMCompositeEstimand) = compute_offset(MLJBase.predict(outcome_equation(Ψ).mach)) -function balancing_weights(scm, W, T; threshold=1e-8) +function balancing_weights(scm::SCM, W, T; threshold=1e-8) density = ones(nrows(T)) for colname ∈ Tables.columnnames(T) mach = scm[colname].mach diff --git a/test/3points_interactions.jl b/test/3points_interactions.jl index 12c052a7..295add55 100644 --- a/test/3points_interactions.jl +++ b/test/3points_interactions.jl @@ -12,39 +12,43 @@ using LogExpFunctions include("helper_fns.jl") -interacter = InteractionTransformer(order=3) |> LinearRegressor() - -function make_dataset(;n=1000) +function dataset_scm_and_truth(;n=1000) rng = StableRNG(123) W = rand(rng, Uniform(), n) T₁ = rand(rng, Uniform(), n) .< W T₂ = rand(rng, Uniform(), n) .< logistic.(1 .- 2W) T₃ = rand(rng, [0, 1], n) - y = 2 .- 2T₁.*T₂.*T₃.*(W .+ 10) + rand(rng, Normal(0, 0.03), n) - return (W=W, T₁=categorical(T₁), T₂=categorical(T₂), T₃=categorical(T₃), y=y), -21 + Y = 2 .- 2T₁.*T₂.*T₃.*(W .+ 10) + rand(rng, Normal(0, 0.03), n) + dataset = (W=W, T₁=categorical(T₁), T₂=categorical(T₂), T₃=categorical(T₃), Y=Y) + scm = StaticConfoundedModel( + [:Y], + [:T₁, :T₂, :T₃], + :W, + outcome_model = TreatmentTransformer() |> InteractionTransformer(order=3) |> LinearRegressor(), + treatment_model = LogisticClassifier(lambda=0) + ) + truth = -21 + return dataset, scm, truth end @testset "Test 3-points interactions" begin - dataset, Ψ₀ = make_dataset(;n=1000) + dataset, scm, Ψ₀ = dataset_scm_and_truth(;n=1000) Ψ = IATE( - outcome = :y, - confounders = [:W], - treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false), T₃=(case=1, control=0)) - ) - η_spec = NuisanceSpec( - interacter, - LogisticClassifier(lambda=0), + scm = scm, + outcome = :Y, + treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false), T₃=(case=1, control=0)) ) - tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0) - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache, outcome_name=:y) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - tmle_result, _ = tmle!(cache, verbosity=0, weighted_fluctuation=true) - test_coverage(tmle_result, Ψ₀) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + result, fluctuation = tmle(Ψ, dataset, verbosity=0) + test_coverage(result, Ψ₀) + test_fluct_decreases_risk(Ψ, fluctuation) + test_mean_inf_curve_almost_zero(result; atol=1e-10) + + result, _ = tmle(Ψ, dataset, verbosity=0, weighted_fluctuation=true) + test_coverage(result, Ψ₀) + test_mean_inf_curve_almost_zero(result; atol=1e-10) end diff --git a/test/helper_fns.jl b/test/helper_fns.jl index 29ab0b15..e622cefc 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -20,10 +20,11 @@ The fluctuation is supposed to decrease the risk as its objective function is th It seems that sometimes this is not entirely true in practice, so the test actually checks that it does not increase risk more than tol """ -function test_fluct_decreases_risk(cache; outcome_name::Symbol=nothing, atol=1e-6) - y = cache.data[:no_missing][outcome_name] - initial_risk = risk(cache.data[:Q₀], y) - fluct_risk = risk(cache.data[:Qfluct], y) +function test_fluct_decreases_risk(Ψ::TMLE.CMCompositeEstimand, fluctuation_mach; atol=1e-6) + outcome_mach = TMLE.outcome_equation(Ψ).mach + y = outcome_mach.data[2] + initial_risk = risk(MLJBase.predict(outcome_mach), y) + fluct_risk = risk(MLJBase.predict(fluctuation_mach), y) @test initial_risk >= fluct_risk || isapprox(initial_risk, fluct_risk, atol=atol) end @@ -34,12 +35,12 @@ end Both the TMLE and OneStep estimators are suppose to be asymptotically efficient and cover the truth at the given confidence level: here 0.05 """ -function test_coverage(tmle_result::TMLE.TMLEResult, Ψ₀) +function test_coverage(result::TMLE.TMLEResult, Ψ₀) # TMLE - lb, ub = confint(OneSampleTTest(tmle_result.tmle)) + lb, ub = confint(OneSampleTTest(result.tmle)) @test lb ≤ Ψ₀ ≤ ub # OneStep - lb, ub = confint(OneSampleTTest(tmle_result.onestep)) + lb, ub = confint(OneSampleTTest(result.onestep)) @test lb ≤ Ψ₀ ≤ ub end From af22035ae97fd3142e0176456ce58716a808bfbc Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 16:21:13 +0100 Subject: [PATCH 015/151] fix weighted fluctuation --- src/cache.jl | 7 ++++--- test/3points_interactions.jl | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/cache.jl b/src/cache.jl index 5222aefe..8c62954d 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -105,10 +105,11 @@ gradient_W(counterfactual_aggregate, estimate) = This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. """ function gradients_Y_X(outcome_mach::Machine, fluctuation_mach::Machine) - X, y = fluctuation_mach.data + X, y, w = fluctuation_mach.data y = float(y) - gradient_Y_Xᵢ = X.covariate .* (y .- expected_value(MLJBase.predict(outcome_mach))) - gradient_Y_X_fluct = X.covariate .* (y .- expected_value(MLJBase.predict(fluctuation_mach))) + covariate = X.covariate .* w + gradient_Y_Xᵢ = covariate .* (y .- expected_value(MLJBase.predict(outcome_mach))) + gradient_Y_X_fluct = covariate .* (y .- expected_value(MLJBase.predict(fluctuation_mach))) return gradient_Y_Xᵢ, gradient_Y_X_fluct end diff --git a/test/3points_interactions.jl b/test/3points_interactions.jl index 295add55..6068805e 100644 --- a/test/3points_interactions.jl +++ b/test/3points_interactions.jl @@ -46,7 +46,7 @@ end test_fluct_decreases_risk(Ψ, fluctuation) test_mean_inf_curve_almost_zero(result; atol=1e-10) - result, _ = tmle(Ψ, dataset, verbosity=0, weighted_fluctuation=true) + result, fluctuation = tmle(Ψ, dataset, verbosity=0, weighted_fluctuation=true) test_coverage(result, Ψ₀) test_mean_inf_curve_almost_zero(result; atol=1e-10) From 8dcfffa4fc0af5f160d91fb3c2b926b49f50c1fe Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 19:58:50 +0100 Subject: [PATCH 016/151] fix some tests --- README.md | 9 +- src/cache.jl | 33 ++++- src/scm.jl | 10 +- test/3points_interactions.jl | 1 - test/double_robustness_ate.jl | 231 +++++++++++++++++----------------- 5 files changed, 159 insertions(+), 125 deletions(-) diff --git a/README.md b/README.md index 4cd969be..b6088e0c 100644 --- a/README.md +++ b/README.md @@ -7,5 +7,12 @@ Targeted Minimum Loss-Based Estimation (TMLE) is a framework for efficient estimation of pathwise differentiable estimands. The goal of this package is to provide an implementation on top of the [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) framework. - **New to TMLE.jl?** The API design is still work in progress but you can already [get started now](https://targene.github.io/TMLE.jl/stable/) + +## TODO + +- Check parameter definition +- Fix tests +- Simplify code +- More API +- Documentation update diff --git a/src/cache.jl b/src/cache.jl index 8c62954d..2183b5c2 100644 --- a/src/cache.jl +++ b/src/cache.jl @@ -1,29 +1,52 @@ +NotIdentifiedError(reasons) = ArgumentError(string( + "The estimand is not identified for the following reasons: \n", join(reasons, "\n"))) + +""" + check_treatment_settings(settings::NamedTuple, levels, treatment_name) + +Checks the case/control values defining the treatment contrast are present in the dataset levels. + +Note: This method is for estimands like the ATE or IATE that have case/control treatment settings represented as +`NamedTuple`. +""" function check_treatment_settings(settings::NamedTuple, levels, treatment_name) for (key, val) in zip(keys(settings), settings) - any(string(val) .== levels) || + any(val .== levels) || throw(ArgumentError(string( - "The '", key, "' string representation: '", val, "' for treatment ", treatment_name, + "The treatment variable ", treatment_name, "'s, '", key, "' value: '", val, " in Ψ does not match any level of the corresponding variable in the dataset: ", string.(levels)))) end end +""" + check_treatment_settings(setting, levels, treatment_name) + +Checks the value defining the treatment setting is present in the dataset levels. + +Note: This is for estimands like the CM that do not have case/control treatment settings +and are represented as simple values. +""" function check_treatment_settings(setting, levels, treatment_name) - any(string(setting) .== levels) || + any(setting .== levels) || throw(ArgumentError(string( "The string representation: '", val, "' for treatment ", treatment_name, " in Ψ does not match any level of the corresponding variable in the dataset: ", string.(levels)))) end -function check_treatment_values(cache, Ψ::Estimand) +function check_estimand(Ψ::Estimand, dataset) + identified, reasons = isidentified(Ψ, dataset) + identified || throw(NotIdentifiedError(reasons)) for treatment_name in treatments(Ψ) - treatment_levels = string.(levels(Tables.getcolumn(cache.data[:source], treatment_name))) + treatment_levels = levels(Tables.getcolumn(dataset, treatment_name)) treatment_settings = getproperty(Ψ.treatment, treatment_name) check_treatment_settings(treatment_settings, treatment_levels, treatment_name) end end function tmle(Ψ::Estimand, dataset; verbosity=1, cache=true, force=false, threshold=1e-8, weighted_fluctuation=false) + # Check the estimand against the dataset + check_estimand(Ψ, dataset) # Initial fit of the SCM's equations verbosity >= 1 && @info "Fitting the required equations..." fit!(Ψ, dataset; verbosity=verbosity, cache=cache, force=force) diff --git a/src/scm.jl b/src/scm.jl index f5003aa8..958e0f02 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -21,7 +21,8 @@ StructuralEquation(outcome, parents; model=nothing) = StructuralEquation(outcome const SE = StructuralEquation function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) - if eq.mach === nothing || force === true + # Fit if never fitted or if new model + if eq.mach === nothing || eq.model != eq.mach.model verbosity >= 1 && @info(string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".")) data = nomissing(dataset, vcat(parents(eq), outcome(eq))) X = selectcols(data, parents(eq)) @@ -29,12 +30,9 @@ function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) mach = machine(eq.model, X, y, cache=cache) MLJBase.fit!(mach, verbosity=verbosity-1) eq.mach = mach + # Otherwise only fit if force is true else - verbosity >= 0 && @info(string( - "Structural Equation corresponding to variable ", - outcome(eq), - " already fitted, skipping. Set `force=true` to force refit." - )) + MLJBase.fit!(eq.mach, verbosity=verbosity-1, force=force) end end diff --git a/test/3points_interactions.jl b/test/3points_interactions.jl index 6068805e..b9b850f7 100644 --- a/test/3points_interactions.jl +++ b/test/3points_interactions.jl @@ -40,7 +40,6 @@ end treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false), T₃=(case=1, control=0)) ) - result, fluctuation = tmle(Ψ, dataset, verbosity=0) test_coverage(result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation) diff --git a/test/double_robustness_ate.jl b/test/double_robustness_ate.jl index 46311fdc..73cf478d 100644 --- a/test/double_robustness_ate.jl +++ b/test/double_robustness_ate.jl @@ -31,13 +31,16 @@ function binary_outcome_binary_treatment_pb(;n=100) # and convert types W = convert(Array{Float64}, w) T = categorical(t) - y = categorical(y) + Y = categorical(y) + dataset = (T=T, W=W, Y=Y) # Compute the theoretical ATE ATE₁ = py_given_aw(1, 1)*p_w() + (1-p_w())*py_given_aw(1, 0) ATE₀ = py_given_aw(0, 1)*p_w() + (1-p_w())*py_given_aw(0, 0) ATE = ATE₁ - ATE₀ + # Define the SCM + scm = StaticConfoundedModel(:Y, :T, :W) - return (T=T, W=W, y=y), ATE + return dataset, scm, ATE end """ @@ -45,36 +48,45 @@ From https://www.degruyter.com/document/doi/10.2202/1557-4679.1043/html The theoretical ATE is 1 """ function continuous_outcome_binary_treatment_pb(;n=100) + # Dataset rng = StableRNG(123) Unif = Uniform(0, 1) W = float(rand(rng, Bernoulli(0.5), n, 3)) W₁, W₂, W₃ = W[:, 1], W[:, 2], W[:, 3] t = rand(rng, Unif, n) .< logistic.(0.5W₁ + 1.5W₂ - W₃) y = 4t + 25W₁ + 3W₂ - 4W₃ + rand(rng, Normal(0, 0.1), n) - # Type coercion T = categorical(t) - return (T = T, W₁ = W₁, W₂ = W₂, W₃ = W₃, y = y), 4 + dataset = (T = T, W₁ = W₁, W₂ = W₂, W₃ = W₃, Y = y) + # Theroretical ATE + ATE = 4 + # Define the SCM + scm = StaticConfoundedModel(:Y, :T, [:W₁, :W₂, :W₃]) + + return dataset, scm, ATE end function continuous_outcome_categorical_treatment_pb(;n=100, control="TT", case="AA") + # Define dataset rng = StableRNG(123) ft(T) = (T .== "AA") - (T .== "AT") + 2(T .== "TT") fw(W₁, W₂, W₃) = 2W₁ + 3W₂ - 4W₃ - W = float(rand(rng, Bernoulli(0.5), n, 3)) W₁, W₂, W₃ = W[:, 1], W[:, 2], W[:, 3] θ = rand(rng, 3, 3) softmax = exp.(W*θ) ./ sum(exp.(W*θ), dims=2) T = [sample(rng, ["TT", "AA", "AT"], Weights(softmax[i, :])) for i in 1:n] y = ft(T) + fw(W₁, W₂, W₃) + rand(rng, Normal(0,1), n) - - # Ew[E[Y|t,w]] = ∑ᵤ (ft(T) + fw(w))p(w) = ft(t) + 0.5 + dataset = (T = categorical(T), W₁ = W₁, W₂ = W₂, W₃ = W₃, Y = y) + # True ATE: Ew[E[Y|t,w]] = ∑ᵤ (ft(T) + fw(w))p(w) = ft(t) + 0.5 ATE = (ft(case) + 0.5) - (ft(control) + 0.5) - return (T = categorical(T), W₁ = W₁, W₂ = W₂, W₃ = W₃, y = y), ATE + # Define the SCM + scm = StaticConfoundedModel(:Y, :T, [:W₁, :W₂, :W₃]) + return dataset, scm, ATE end function dataset_2_treatments_pb(;rng = StableRNG(123), n=100) + # Dataset μY(W₁, W₂, T₁, T₂) = 4T₁ .- 2T₂ .+ 5W₁ .- 3W₂ W₁ = rand(rng, Normal(), n) W₂ = rand(rng, Normal(), n) @@ -83,165 +95,160 @@ function dataset_2_treatments_pb(;rng = StableRNG(123), n=100) μT₂ = logistic.(-1.5W₁ + .5W₂ .+ 1) T₂ = float(rand(rng, Uniform(), n) .< μT₂) y = μY(W₁, W₂, T₁, T₂) .+ rand(rng, Normal(), n) - # Those ATEs are MC approximations, only reliable with large samples - case = ones(n) - control = zeros(n) - ATE₁₁₋₀₁ = mean(μY(W₁, W₂, case, case) .- μY(W₁, W₂, control, case)) - ATE₁₁₋₀₀ = mean(μY(W₁, W₂, case, case) .- μY(W₁, W₂, control, control)) - return ( + dataset = ( W₁ = W₁, W₂ = W₂, T₁ = categorical(T₁), T₂ = categorical(T₂), - y = y - ), ATE₁₁₋₀₁, ATE₁₁₋₀₀ -end - -@testset "Test Double Robustness ATE with two treatment variables" begin - dataset, ATE₁₁₋₀₁, ATE₁₁₋₀₀ = dataset_2_treatments_pb(;rng = StableRNG(123), n=10_000) - cache = TMLECache(dataset) - # Test first ATE, only T₁ treatment varies - Ψ = ATE( - outcome = :y, - treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=1.)), - confounders = [:W₁, :W₂] - ) - # When Q is misspecified but G is well specified - η_spec = NuisanceSpec( - MLJModels.DeterministicConstantRegressor(), - LogisticClassifier(lambda=0) + Y = y ) - tmle_result, cache = tmle!(cache, Ψ, η_spec, verbosity=0) - test_coverage(tmle_result, ATE₁₁₋₀₁) - test_fluct_decreases_risk(cache; outcome_name=:y) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) - # When Q is well specified but G is misspecified - η_spec = NuisanceSpec( - LinearRegressor(), - ConstantClassifier() - ) - tmle_result, cache = tmle!(cache, η_spec, verbosity=0) - test_coverage(tmle_result, ATE₁₁₋₀₁) - test_fluct_decreases_risk(cache; outcome_name=:y) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + # Those ATEs are MC approximations, only reliable with large samples + case = ones(n) + control = zeros(n) + ATE₁₁₋₀₁ = mean(μY(W₁, W₂, case, case) .- μY(W₁, W₂, control, case)) + ATE₁₁₋₀₀ = mean(μY(W₁, W₂, case, case) .- μY(W₁, W₂, control, control)) + # Define the SCM + scm = StaticConfoundedModel([:Y], [:T₁, :T₂], [:W₁, :W₂]) - # Test second ATE, two treatment varies - Ψ = ATE( - outcome = :y, - treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=0.)), - confounders = [:W₁, :W₂] - ) - tmle_result, cache = tmle!(cache, Ψ, η_spec, verbosity=0) - test_coverage(tmle_result, ATE₁₁₋₀₀) - test_fluct_decreases_risk(cache; outcome_name=:y) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) - # When Q is well specified but G is misspecified - η_spec = NuisanceSpec( - LinearRegressor(), - ConstantClassifier() - ) - tmle_result, cache = tmle!(cache, η_spec, verbosity=0) - test_coverage(tmle_result, ATE₁₁₋₀₀) - test_fluct_decreases_risk(cache; outcome_name=:y) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + return dataset, scm , (ATE₁₁₋₀₁, ATE₁₁₋₀₀) end @testset "Test Double Robustness ATE on continuous_outcome_categorical_treatment_pb" begin - dataset, Ψ₀ = continuous_outcome_categorical_treatment_pb(;n=10_000, control="TT", case="AA") + dataset, scm, Ψ₀ = continuous_outcome_categorical_treatment_pb(;n=10_000, control="TT", case="AA") Ψ = ATE( - outcome = :y, - treatment = (T=(case="AA", control="TT"),), - confounders = [:W₁, :W₂, :W₃] + scm = scm, + outcome = :Y, + treatment = (T=(case="AA", control="TT"),), ) # When Q is misspecified but G is well specified - η_spec = NuisanceSpec( - MLJModels.DeterministicConstantRegressor(), - LogisticClassifier(lambda=0) - ) - cache = TMLECache(dataset) - tmle_result, cache = tmle!(cache, Ψ, η_spec, verbosity=0) + scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() + scm.T.model = LogisticClassifier(lambda=0) + + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial == 0 # When Q is well specified but G is misspecified - η_spec = NuisanceSpec( - LinearRegressor(), - ConstantClassifier() - ) + scm.Y.model = TreatmentTransformer() |> LinearRegressor() + scm.T.model = ConstantClassifier() - tmle_result, cache = tmle!(cache, η_spec, verbosity=0) + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache, outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @testset "Test Double Robustness ATE on binary_outcome_binary_treatment_pb" begin - dataset, Ψ₀ = binary_outcome_binary_treatment_pb(;n=10_000) + dataset, scm, Ψ₀ = binary_outcome_binary_treatment_pb(;n=10_000) Ψ = ATE( - outcome = :y, + scm = scm, + outcome = :Y, treatment = (T=(case=true, control=false),), - confounders = [:W] ) # When Q is misspecified but G is well specified - η_spec = NuisanceSpec( - ConstantClassifier(), - LogisticClassifier(lambda=0) - ) - tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0) + scm.Y.model = TreatmentTransformer() |> ConstantClassifier() + scm.T.model = LogisticClassifier(lambda=0) + + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial == 0 # When Q is well specified but G is misspecified - η_spec = NuisanceSpec( - LogisticClassifier(lambda=0), - ConstantClassifier() - ) - tmle_result, cache = tmle!(cache, η_spec, verbosity=0) + scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) + scm.T.model = ConstantClassifier() + + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) end @testset "Test Double Robustness ATE on continuous_outcome_binary_treatment_pb" begin - dataset, Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) + dataset, scm, Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) Ψ = ATE( - outcome = :y, + scm=scm, + outcome = :Y, treatment = (T=(case=true, control=false),), - confounders = [:W₁, :W₂, :W₃] ) # When Q is misspecified but G is well specified - η_spec = NuisanceSpec( - MLJModels.DeterministicConstantRegressor(), - LogisticClassifier(lambda=0) - ) + scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() + scm.T.model = LogisticClassifier(lambda=0) - tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial == 0 # When Q is well specified but G is misspecified - η_spec = NuisanceSpec( - LinearRegressor(), - ConstantClassifier() - ) - tmle_result, cache = tmle!(cache, η_spec, verbosity=0) + scm.Y.model = TreatmentTransformer() |> LinearRegressor() + scm.T.model = ConstantClassifier() + + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) +end + +@testset "Test Double Robustness ATE with two treatment variables" begin + dataset, scm, (ATE₁₁₋₀₁, ATE₁₁₋₀₀) = dataset_2_treatments_pb(;rng = StableRNG(123), n=50_000) + # Test first ATE, only T₁ treatment varies + Ψ = ATE( + scm=scm, + outcome = :Y, + treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=1.)), + ) + # When Q is misspecified but G is well specified + scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() + scm.T₁.model = LogisticClassifier(lambda=0) + scm.T₂.model = LogisticClassifier(lambda=0) + + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + test_coverage(tmle_result, ATE₁₁₋₀₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + # When Q is well specified but G is misspecified + scm.Y.model = TreatmentTransformer() |> LinearRegressor() + scm.T₁.model = ConstantClassifier() + scm.T₂.model = ConstantClassifier() + + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + test_coverage(tmle_result, ATE₁₁₋₀₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + + # Test second ATE, two treatment varies + Ψ = ATE( + scm = scm, + outcome = :Y, + treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=0.)), + ) + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=1) + test_coverage(tmle_result, ATE₁₁₋₀₀) + test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + + # When Q is well specified but G is misspecified + scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() + scm.T₁.model = LogisticClassifier(lambda=0) + scm.T₂.model = LogisticClassifier(lambda=0) + + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + test_coverage(tmle_result, ATE₁₁₋₀₀) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end From 6afc9a905d10bd9e43886070abdd47031199f669 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 20:36:26 +0100 Subject: [PATCH 017/151] fix more tests --- test/composition.jl | 61 ++++++++++-------- test/double_robustness_iate.jl | 111 +++++++++++++++++---------------- test/fit_nuisance.jl | 105 ------------------------------- test/missing_management.jl | 24 +++---- test/runtests.jl | 1 - 5 files changed, 105 insertions(+), 197 deletions(-) delete mode 100644 test/fit_nuisance.jl diff --git a/test/composition.jl b/test/composition.jl index 3e4f175c..2f5bf3f8 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -8,24 +8,20 @@ using MLJLinearModels using TMLE using CategoricalArrays -function make_dataset(;n=100) +function make_dataset_and_scm(;n=100) rng = StableRNG(123) W = rand(rng, Uniform(), n) T = rand(rng, [0, 1], n) - y = 3W .+ T .+ T.*W + rand(rng, Normal(0, 0.05), n) - return ( - y = y, + Y = 3W .+ T .+ T.*W + rand(rng, Normal(0, 0.05), n) + dataset = ( + Y = Y, W = W, T = categorical(T) ) + scm = StaticConfoundedModel(:Y, :T, :W) + return dataset, scm end -basicCM(val) = CM( - outcome = :y, - treatment = (T=val,), - confounders = [:W] -) - @testset "Test cov" begin n = 10 X = rand(n, 2) @@ -37,37 +33,50 @@ basicCM(val) = CM( end @testset "Test composition CM(1) - CM(0) = ATE(1,0)" begin - dataset = make_dataset(;n=1000) - η_spec = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0) - ) + dataset, scm = make_dataset_and_scm(;n=1000) # Conditional Mean T = 1 - CM_result₁, cache = tmle(basicCM(1), η_spec, dataset, verbosity=0) + CM₁ = CM( + scm = scm, + outcome = :Y, + treatment = (T=1,) + ) + CM_result₁, _ = tmle(CM₁, dataset, verbosity=0) # Conditional Mean T = 0 - CM_result₀, cache = tmle!(cache, basicCM(0), verbosity=0) + CM₀ = CM( + scm = scm, + outcome = :Y, + treatment = (T=0,) + ) + CM_result₀, _ = tmle(CM₀, dataset, verbosity=0) # Via Composition CM_result_composed = compose(-, CM_result₁.tmle, CM_result₀.tmle) # Via ATE ATE₁₀ = ATE( - outcome = :y, + scm=scm, + outcome = :Y, treatment = (T=(case=1, control=0),), - confounders = [:W] ) - ATE_result₁₀, cache = tmle!(cache, ATE₁₀, verbosity=0) + ATE_result₁₀, _ = tmle(ATE₁₀, dataset, verbosity=0) @test TMLE.estimate(ATE_result₁₀.tmle) ≈ TMLE.estimate(CM_result_composed) atol = 1e-7 @test var(ATE_result₁₀.tmle) ≈ var(CM_result_composed) atol = 1e-7 end @testset "Test compose multidimensional function" begin - dataset = make_dataset(;n=1000) - η_spec = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0) + dataset, scm = make_dataset_and_scm(;n=1000) + CM₁ = CM( + scm = scm, + outcome = :Y, + treatment = (T=1,) + ) + CM_result₁, _ = tmle(CM₁, dataset, verbosity=0) + + CM₀ = CM( + scm = scm, + outcome = :Y, + treatment = (T=0,) ) - CM_result₁, cache = tmle(basicCM(1), η_spec, dataset, verbosity=0) - CM_result₀, cache = tmle!(cache, basicCM(0), verbosity=0) + CM_result₀, _ = tmle(CM₀, dataset, verbosity=0) f(x, y) = [x^2 - y, x/y, 2x + 3y] CM_result_composed = compose(f, CM_result₁.tmle, CM_result₀.tmle) diff --git a/test/double_robustness_iate.jl b/test/double_robustness_iate.jl index f89f2695..fb89060d 100644 --- a/test/double_robustness_iate.jl +++ b/test/double_robustness_iate.jl @@ -39,8 +39,8 @@ function binary_outcome_binary_treatment_pb(;n=100) W = float(W) T₁ = categorical(T₁) T₂ = categorical(T₂) - y = categorical(y) - + Y = categorical(y) + dataset = (T₁=T₁, T₂=T₂, W₁=W[:, 1], W₂=W[:, 2], W₃=W[:, 3], Y=Y) # Compute the theoretical IATE Wcomb = [1 1 1; 1 1 0; @@ -59,7 +59,10 @@ function binary_outcome_binary_treatment_pb(;n=100) temp -= μy_fn(w, [0], [1])[1] IATE += temp*0.5*0.5*0.5 end - return (T₁=T₁, T₂=T₂, W₁=W[:, 1], W₂=W[:, 2], W₃=W[:, 3], y=y), IATE + # Define the SCM + scm = StaticConfoundedModel([:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃]) + + return dataset, scm, IATE end @@ -95,7 +98,7 @@ function binary_outcome_categorical_treatment_pb(;n=100) # Sampling y from T, W: μy = μy_fn(W, T, Hmach) y = [rand(rng, Bernoulli(μy[i])) for i in 1:n] - + dataset = (T₁=T.T₁, T₂=T.T₂, W₁=W[:, 1], W₂=W[:, 2], W₃=W[:, 3], Y=categorical(y)) # Compute the theoretical IATE for the query # (CC, AT) against (CG, AA) Wcomb = [1 1 1; @@ -117,7 +120,10 @@ function binary_outcome_categorical_treatment_pb(;n=100) temp -= μy_fn(w, (T₁=categorical(["CG"], levels=levels₁), T₂=categorical(["AT"], levels=levels₂)), Hmach)[1] IATE += temp*0.5*0.5*0.5 end - return (T₁=T.T₁, T₂=T.T₂, W₁=W[:, 1], W₂=W[:, 2], W₃=W[:, 3], y=categorical(y)), IATE + # Define the SCM + scm = StaticConfoundedModel([:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃]) + + return dataset, scm, IATE end @@ -143,6 +149,7 @@ function continuous_outcome_binary_treatment_pb(;n=100) T₁ = categorical(T₁) T₂ = categorical(T₂) + dataset = (T₁=T₁, T₂=T₂, W₁=W[:, 1], W₂=W[:, 2], W₃=W[:, 3], Y=y) # Compute the theoretical ATE Wcomb = [1 1 1; 1 1 0; @@ -162,110 +169,108 @@ function continuous_outcome_binary_treatment_pb(;n=100) IATE += temp*0.5*0.5*0.5 end - return (T₁=T₁, T₂=T₂, W₁=W[:, 1], W₂=W[:, 2], W₃=W[:, 3], y=y), IATE + # Define the SCM + scm = StaticConfoundedModel([:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃]) + + return dataset, scm, IATE end @testset "Test Double Robustness IATE on binary_outcome_binary_treatment_pb" begin - dataset, Ψ₀ = binary_outcome_binary_treatment_pb(n=10_000) + dataset, scm, Ψ₀ = binary_outcome_binary_treatment_pb(n=10_000) Ψ = IATE( - outcome=:y, + scm = scm, + outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders = [:W₁, :W₂, :W₃] ) # When Q is misspecified but G is well specified - η_spec = NuisanceSpec( - ConstantClassifier(), - LogisticClassifier(lambda=0) - ) - tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0); + scm.Y.model = TreatmentTransformer() |> ConstantClassifier() + scm.T₁.model = LogisticClassifier(lambda=0) + scm.T₂.model = LogisticClassifier(lambda=0) + + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial == 0 # When Q is well specified but G is misspecified - η_spec = NuisanceSpec( - LogisticClassifier(lambda=0), - ConstantClassifier() - ) + scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) + scm.T₁.model = ConstantClassifier() + scm.T₂.model = ConstantClassifier() - tmle_result, cache = tmle!(cache, η_spec, verbosity=0); + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-7) + test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial ≈ -0.0 atol=1e-1 end @testset "Test Double Robustness IATE on continuous_outcome_binary_treatment_pb" begin - dataset, Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) + dataset, scm,Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) Ψ = IATE( - outcome=:y, + scm=scm, + outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders = [:W₁, :W₂, :W₃] ) # When Q is misspecified but G is well specified - η_spec = NuisanceSpec( - MLJModels.DeterministicConstantRegressor(), - LogisticClassifier(lambda=0) - ) + scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() + scm.T₁.model = LogisticClassifier(lambda=0) + scm.T₂.model = LogisticClassifier(lambda=0) - tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial == 0 # When Q is well specified but G is misspecified - η_spec = NuisanceSpec( - cont_interacter, - ConstantClassifier() - ) + scm.Y.model = TreatmentTransformer() |> cont_interacter + scm.T₁.model = ConstantClassifier() + scm.T₂.model = ConstantClassifier() - tmle_result, cache = tmle!(cache, η_spec, verbosity=0) + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @testset "Test Double Robustness IATE on binary_outcome_categorical_treatment_pb" begin - dataset, Ψ₀ = binary_outcome_categorical_treatment_pb(n=10_000) + dataset, scm, Ψ₀ = binary_outcome_categorical_treatment_pb(n=30_000) Ψ = IATE( - outcome=:y, + scm=scm, + outcome=:Y, treatment=(T₁=(case="CC", control="CG"), T₂=(case="AT", control="AA")), - confounders = [:W₁, :W₂, :W₃] ) # When Q is misspecified but G is well specified - η_spec = NuisanceSpec( - ConstantClassifier(), - LogisticClassifier(lambda=0) - ) + scm.Y.model = TreatmentTransformer() |> ConstantClassifier() + scm.T₁.model = LogisticClassifier(lambda=0) + scm.T₂.model = LogisticClassifier(lambda=0) - tmle_result, cache = tmle(Ψ, η_spec, dataset, verbosity=0); + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away @test tmle_result.initial == 0 # When Q is well specified but G is misspecified - η_spec = NuisanceSpec( - cat_interacter, - ConstantClassifier() - ) + scm.Y.model = TreatmentTransformer() |> cat_interacter + scm.T₁.model = ConstantClassifier() + scm.T₂.model = ConstantClassifier() - tmle_result, cache = tmle!(cache, η_spec, verbosity=0) + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test tmle_result.initial ≈ -0.01 atol=1e-2 + @test tmle_result.initial ≈ -0.02 atol=1e-2 end diff --git a/test/fit_nuisance.jl b/test/fit_nuisance.jl deleted file mode 100644 index d039584e..00000000 --- a/test/fit_nuisance.jl +++ /dev/null @@ -1,105 +0,0 @@ -module TestFitNuisance - -using Test -using TMLE -using StableRNGs -using Distributions -using LogExpFunctions -using Random -using DataFrames -using CategoricalArrays -using MLJLinearModels -using MLJBase - -function make_dataset(;n=100, rng=StableRNG(123)) - μy_cont(W₁, W₂, W₃, T₁) = 2W₁ .- 4W₂ .- 3W₃ .+ 2T₁ - μy_bin(W₁, W₂, W₃, T₁) = logistic.(μy_cont(W₁, W₂, W₃, T₁)) - W₁ = rand(rng, Normal(), n) - W₂ = rand(rng, Normal(), n) - W₃ = rand(rng, Normal(), n) - μT₁ = logistic.(0.3W₁ .- 1.4W₂ .+ 1.9W₃) - T₁ = rand(rng, Uniform(), n) .< μT₁ - Ybin = rand(rng, Uniform(), n) .< μy_bin(W₁, W₂, W₃, T₁) - Ycont = μy_cont(W₁, W₂, W₃, T₁) + rand(rng, Normal(), n) - return DataFrame( - W₁ = W₁, - W₂ = W₂, - W₃ = W₃, - T₁ = categorical(T₁), - Ybin = categorical(Ybin), - Ycont = Ycont - ) -end - -@testset "Test fit_nuisance! binary Y" begin - dataset = make_dataset(;n=50_000, rng=StableRNG(123)) - cache = TMLECache(dataset) - Ψ = ATE( - outcome=:Ybin, - treatment=(T₁=(case=true, control=false),), - confounders = [:W₁, :W₂, :W₃] - ) - η_spec = NuisanceSpec( - LogisticClassifier(lambda=0.), - LogisticClassifier(lambda=0.) - ) - update!(cache, Ψ, η_spec) - TMLE.fit_nuisance!(cache, verbosity=0) - # Check G fit - Gcoefs = fitted_params(cache.η.G).coefs - @test Gcoefs[1][2] ≈ 0.298 atol=0.001 # outcome is 0.3 - @test Gcoefs[2][2] ≈ -1.441 atol=0.001 # outcome is -1.4 - @test Gcoefs[3][2] ≈ 1.953 atol=0.001 # outcome is 1.9 - # Check Q fit - Qcoefs = fitted_params(cache.η.Q).coefs - @test Qcoefs[1][2] ≈ 2.000 atol=0.001 # outcome is 2 - @test Qcoefs[2][2] ≈ -4.006 atol=0.001 # outcome is -4 - @test Qcoefs[3][2] ≈ -3.046 atol=0.001 # outcome is -3 - @test Qcoefs[4][2] ≈ -2.062 atol=0.001 # outcome is 2 on T₁=1 but T₁=0 is encoded - @test fitted_params(cache.η.Q).intercept ≈ 2.053 atol=0.001 - - # Fluctuate the inital Q, the log_loss should decrease - TMLE.tmle_step!(cache; verbosity=0, threshold=1e-8) - ll_fluct = mean(log_loss(cache.data[:Qfluct], dataset.Ybin)) - ll_init = mean(log_loss(cache.data[:Q₀], dataset.Ybin)) - @test ll_fluct < ll_init -end - - -@testset "Test fit_nuisance! continuous Y" begin - dataset = make_dataset(;n=50_000, rng=StableRNG(123)) - cache = TMLECache(dataset) - Ψ = ATE( - outcome=:Ycont, - treatment=(T₁=(case=true, control=false),), - confounders = [:W₁, :W₂, :W₃] - ) - η_spec = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0.) - ) - update!(cache, Ψ, η_spec) - TMLE.fit_nuisance!(cache, verbosity=0) - # Check G fit - Gcoefs = fitted_params(cache.η.G).coefs - @test Gcoefs[1][2] ≈ 0.298 atol=0.001 # outcome is 0.3 - @test Gcoefs[2][2] ≈ -1.441 atol=0.001 # outcome is -1.4 - @test Gcoefs[3][2] ≈ 1.953 atol=0.001 # outcome is 1.9 - # Check Q fit - Qcoefs = fitted_params(cache.η.Q).coefs - @test Qcoefs[1][2] ≈ 2.001 atol=0.001 # outcome is 2 - @test Qcoefs[2][2] ≈ -4.002 atol=0.001 # outcome is -4 - @test Qcoefs[3][2] ≈ -2.999 atol=0.001 # outcome is -3 - @test Qcoefs[4][2] ≈ -1.988 atol=0.001 # outcome is 2 on T₁=1 but T₁=0 is encoded - @test fitted_params(cache.η.Q).intercept ≈ 1.996 atol=0.001 - - # Fluctuate the inital Q, the RMSE should decrease - TMLE.tmle_step!(cache; verbosity=0, threshold=1e-8) - rmse_fluct = rmse(mean.(cache.data[:Qfluct]), dataset.Ycont) - rmse_init = rmse(cache.data[:Q₀], dataset.Ycont) - @test rmse_fluct < rmse_init -end - -end - -true \ No newline at end of file diff --git a/test/missing_management.jl b/test/missing_management.jl index 4d491a76..f38b8085 100644 --- a/test/missing_management.jl +++ b/test/missing_management.jl @@ -14,18 +14,19 @@ function dataset_with_missing_and_ordered_treatment(;n=1000) rng = StableRNG(123) W = rand(rng, n) T = rand(rng, [0, 1, 2], n) - y = T + 3W + randn(rng, n) - dataset = DataFrame(W = W, T = categorical(T, ordered=true, levels=[0, 1, 2]), y = y) + Y = T + 3W + randn(rng, n) + dataset = DataFrame(W = W, T = categorical(T, ordered=true, levels=[0, 1, 2]), Y = Y) allowmissing!(dataset) dataset.W[1:5] .= missing dataset.T[6:10] .= missing - dataset.y[11:15] .= missing - return dataset + dataset.Y[11:15] .= missing + scm = StaticConfoundedModel(:Y, :T, :W, treatment_model = LogisticClassifier(lambda=0)) + return dataset, scm end @testset "Test nomissing" begin - dataset = dataset_with_missing_and_ordered_treatment(;n=100) + dataset, scm = dataset_with_missing_and_ordered_treatment(;n=100) # filter missing rows based on W column filtered = TMLE.nomissing(dataset, [:W]) @test filtered.W == dataset.W[6:end] @@ -37,19 +38,18 @@ end filtered = TMLE.nomissing(dataset) @test filtered.W == dataset.W[16:end] @test filtered.T == dataset.T[16:end] - @test filtered.y == dataset.y[16:end] + @test filtered.Y == dataset.Y[16:end] end @testset "Test estimation with missing values and ordered factor treatment" begin - dataset = dataset_with_missing_and_ordered_treatment(;n=1000) + dataset, scm = dataset_with_missing_and_ordered_treatment(;n=1000) Ψ = ATE( - outcome=:y, - confounders=[:W], + scm = scm, + outcome=:Y, treatment=(T=(case=1, control=0),)) - η_spec = NuisanceSpec(LinearRegressor(), LogisticClassifier(lambda=0)) - tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0) + tmle_result, fluctuation_mach = tmle(Ψ, dataset; verbosity=0) test_coverage(tmle_result, 1) - test_fluct_decreases_risk(cache; outcome_name=:y) + test_fluct_decreases_risk(Ψ, fluctuation_mach) end end diff --git a/test/runtests.jl b/test/runtests.jl index bf159d6c..416ea0f4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,5 @@ using Test @test include("missing_management.jl") @test include("composition.jl") @test include("treatment_transformer.jl") - @test include("fit_nuisance.jl") @test include("configuration.jl") end \ No newline at end of file From e4472532da154c3451607853facdbfbdcc836667 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 4 Aug 2023 21:33:58 +0100 Subject: [PATCH 018/151] fix estimads tests --- src/estimands.jl | 34 +++++++++++++----- test/estimands.jl | 87 +++++++++++++++++++++++++++++------------------ 2 files changed, 79 insertions(+), 42 deletions(-) diff --git a/src/estimands.jl b/src/estimands.jl index 7e0c479d..dc2071dc 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -63,12 +63,16 @@ CM₂ = CM( ) ``` """ -@option struct CM <: Estimand +@option struct ConditionalMean <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple end +const CM = ConditionalMean + +name(::Type{CM}) = "CM" + ##################################################################### ### Average Treatment Effect ### ##################################################################### @@ -106,12 +110,16 @@ ATE₂ = ATE( ) ``` """ -@option struct ATE <: Estimand +@option struct AverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple end +const ATE = AverageTreatmentEffect + +name(::Type{ATE}) = "ATE" + ##################################################################### ### Interaction Average Treatment Effect ### ##################################################################### @@ -142,12 +150,15 @@ IATE₁ = IATE( ) ``` """ -@option struct IATE <: Estimand +@option struct InteractionAverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple end +const IATE = InteractionAverageTreatmentEffect + +name(::Type{IATE}) = "IATE" ##################################################################### ### Methods ### @@ -155,7 +166,15 @@ end CMCompositeEstimand = Union{CM, ATE, IATE} -Base.show(io::IO, Ψ::T) where T <: CMCompositeEstimand = println(io, T) +function Base.show(io::IO, Ψ::T) where T <: CMCompositeEstimand + param_string = string( + name(T), + "\n-----", + "\nOutcome: ", Ψ.outcome, + "\nTreatment: ", Ψ.treatment + ) + println(io, param_string) +end equations_to_fit(Ψ::CMCompositeEstimand) = (outcome_equation(Ψ), (Ψ.scm[t] for t in treatments(Ψ))...) @@ -236,10 +255,9 @@ namedtuples_from_dicts(d::Dict) = function param_key(Ψ::CMCompositeEstimand) return ( - join(Ψ.confounders, "_"), - join(keys(Ψ.treatment), "_"), - string(Ψ.outcome), - join(Ψ.covariates, "_") + join(values(confounders(Ψ))..., "_"), + join(treatments(Ψ), "_"), + string(outcome(Ψ)), ) end diff --git a/test/estimands.jl b/test/estimands.jl index 76077064..2c0a287d 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -121,8 +121,6 @@ end outcome =:Ycat, ) log_sequence = ( - (:info, "Structural Equation corresponding to variable Ycat already fitted, skipping. Set `force=true` to force refit."), - (:info, "Structural Equation corresponding to variable T₁ already fitted, skipping. Set `force=true` to force refit."), (:info, "Fitting Structural Equation corresponding to variable T₂."), ) @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) @@ -139,73 +137,94 @@ end ) log_sequence = ( (:info, "Fitting Structural Equation corresponding to variable Ycont."), - (:info, "Structural Equation corresponding to variable T₁ already fitted, skipping. Set `force=true` to force refit."), - (:info, "Structural Equation corresponding to variable T₂ already fitted, skipping. Set `force=true` to force refit."), ) @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) @test scm.Ycat.mach isa Machine @test scm.T₁.mach isa Machine @test scm.T₂.mach isa Machine @test scm.Ycont.mach isa Machine + + # Change a model + scm.Ycont.model = TreatmentTransformer() |> LinearRegressor(fit_intercept=false) + log_sequence = ( + (:info, "Fitting Structural Equation corresponding to variable Ycont."), + ) + @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) end @testset "Test optimize_ordering" begin + rng = StableRNG(123) + scm = SCM( + SE(:T₁, [:W₁, :W₂]), + SE(:T₂, [:W₁, :W₂, :W₃]), + SE(:Y₁, [:T₁, :T₂, :W₁, :W₂, :C₁]), + SE(:Y₂, [:T₂, :W₁, :W₂, :W₃]), + ) estimands = [ ATE( - outcome=:Y, + scm=scm, + outcome=:Y₁, treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), - confounders=[:W₁, :W₂] + ), + IATE( + scm=scm, + outcome=:Y₁, + treatment=(T₁=(case=1, control=0), T₂=(case="AA", control="CC")), ), ATE( - outcome=:Y, + scm=scm, + outcome=:Y₁, treatment=(T₁=(case=1, control=0),), - confounders=[:W₁] ), ATE( + scm=scm, outcome=:Y₂, - treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), - confounders=[:W₁, :W₂], + treatment=(T₂=(case="AC", control="CC"),), ), CM( - outcome=:Y, - treatment=(T₁=0,), - confounders=[:W₁], - covariates=[:C₁] + scm=scm, + outcome=:Y₂, + treatment=(T₂="AC",), ), IATE( - outcome=:Y, - treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), - confounders=[:W₁, :W₂] + scm=scm, + outcome=:Y₁, + treatment=(T₁=(case=0, control=1), T₂=(case="AC", control="CC"),), ), CM( - outcome=:Y, - treatment=(T₁=0,), - confounders=[:W₁] + scm=scm, + outcome=:Y₂, + treatment=(T₂="CC",), ), ATE( + scm=scm, outcome=:Y₂, - treatment=(T₁=(case=1, control=0),), - confounders=[:W₁], - covariates=[:C₁] + treatment=(T₂=(case="AA", control="CC"),), ), ATE( + scm=scm, outcome=:Y₂, - treatment=(T₁=(case=0, control=1), T₂=(case="AC", control="CC")), - confounders=[:W₁, :W₂], - covariates=[:C₂] + treatment=(T₂=(case="AA", control="AC"),), ), ] + # Test param_key + @test TMLE.param_key(estimands[1]) == ("W₁_W₂", "T₁_T₂", "Y₁") + @test TMLE.param_key(estimands[end]) == ("W₁_W₂_W₃", "T₂", "Y₂") # Non mutating function + estimands = shuffle(rng, estimands) ordered_estimands = optimize_ordering(estimands) expected_ordering = [ - ATE(:Y, (T₁ = (case = 1, control = 0),), [:W₁], Symbol[]), - CM(:Y, (T₁ = 0,), [:W₁], Symbol[]), - CM(:Y, (T₁ = 0,), [:W₁], [:C₁]), - ATE(:Y₂, (T₁ = (case = 1, control = 0),), [:W₁], [:C₁]), - ATE(:Y, (T₁ = (case = 1, control = 0), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], Symbol[]), - IATE(:Y, (T₁ = (case = 1, control = 0), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], Symbol[]), - ATE(:Y₂, (T₁ = (case = 1, control = 0), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], Symbol[]), - ATE(:Y₂, (T₁ = (case = 0, control = 1), T₂ = (case = "AC", control = "CC")), [:W₁, :W₂], [:C₂]) + # Y₁ + ATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),)), + ATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AC", control = "CC"))), + IATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 0, control = 1), T₂ = (case = "AC", control = "CC"),)), + IATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AA", control = "CC"))), + # Y₂ + ATE(scm=scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "AC"),)), + CM(scm=scm, outcome=:Y₂, treatment=(T₂ = "CC",)), + ATE(scm=scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "CC"),)), + CM(scm=scm, outcome=:Y₂, treatment=(T₂ = "AC",)), + ATE(scm=scm, outcome=:Y₂, treatment=(T₂ = (case = "AC", control = "CC"),)), ] @test ordered_estimands == expected_ordering # Mutating function From 157cedc0b566571698c5d1a6404b3ab256b0adff Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 6 Aug 2023 15:16:42 +0100 Subject: [PATCH 019/151] rename some files and add more tests --- src/TMLE.jl | 2 +- src/{cache.jl => estimation.jl} | 44 ---- src/utils.jl | 54 +++++ test/cache.jl | 277 ------------------------ test/{warm_restart.jl => estimation.jl} | 25 ++- test/runtests.jl | 3 +- test/utils.jl | 41 +++- 7 files changed, 113 insertions(+), 333 deletions(-) rename src/{cache.jl => estimation.jl} (70%) delete mode 100644 test/cache.jl rename test/{warm_restart.jl => estimation.jl} (96%) diff --git a/src/TMLE.jl b/src/TMLE.jl index 0d465aed..5d7df5c1 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -42,7 +42,7 @@ include("scm.jl") include("treatment_transformer.jl") include("estimands.jl") include("utils.jl") -include("cache.jl") +include("estimation.jl") include("estimate.jl") include("configuration.jl") diff --git a/src/cache.jl b/src/estimation.jl similarity index 70% rename from src/cache.jl rename to src/estimation.jl index 2183b5c2..12fbccc2 100644 --- a/src/cache.jl +++ b/src/estimation.jl @@ -1,48 +1,4 @@ -NotIdentifiedError(reasons) = ArgumentError(string( - "The estimand is not identified for the following reasons: \n", join(reasons, "\n"))) - -""" - check_treatment_settings(settings::NamedTuple, levels, treatment_name) - -Checks the case/control values defining the treatment contrast are present in the dataset levels. - -Note: This method is for estimands like the ATE or IATE that have case/control treatment settings represented as -`NamedTuple`. -""" -function check_treatment_settings(settings::NamedTuple, levels, treatment_name) - for (key, val) in zip(keys(settings), settings) - any(val .== levels) || - throw(ArgumentError(string( - "The treatment variable ", treatment_name, "'s, '", key, "' value: '", val, - " in Ψ does not match any level of the corresponding variable in the dataset: ", string.(levels)))) - end -end - -""" - check_treatment_settings(setting, levels, treatment_name) - -Checks the value defining the treatment setting is present in the dataset levels. - -Note: This is for estimands like the CM that do not have case/control treatment settings -and are represented as simple values. -""" -function check_treatment_settings(setting, levels, treatment_name) - any(setting .== levels) || - throw(ArgumentError(string( - "The string representation: '", val, "' for treatment ", treatment_name, - " in Ψ does not match any level of the corresponding variable in the dataset: ", string.(levels)))) -end - -function check_estimand(Ψ::Estimand, dataset) - identified, reasons = isidentified(Ψ, dataset) - identified || throw(NotIdentifiedError(reasons)) - for treatment_name in treatments(Ψ) - treatment_levels = levels(Tables.getcolumn(dataset, treatment_name)) - treatment_settings = getproperty(Ψ.treatment, treatment_name) - check_treatment_settings(treatment_settings, treatment_levels, treatment_name) - end -end function tmle(Ψ::Estimand, dataset; verbosity=1, cache=true, force=false, threshold=1e-8, weighted_fluctuation=false) # Check the estimand against the dataset diff --git a/src/utils.jl b/src/utils.jl index 6a8748e9..6a19d0c6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -67,6 +67,60 @@ function indicator_values(indicators, T) return indic end +############################################################################### +## Estimand Validation +############################################################################### + +NotIdentifiedError(reasons) = ArgumentError(string( + "The estimand is not identified for the following reasons: \n\t- ", join(reasons, "\n\t- "))) + +AbsentLevelError(treatment_name, key, val, levels) = ArgumentError(string( + "The treatment variable ", treatment_name, "'s, '", key, "' level: '", val, + "' in Ψ does not match any level in the dataset: ", levels)) + +AbsentLevelError(treatment_name, val, levels) = ArgumentError(string( + "The treatment variable ", treatment_name, "'s, level: '", val, + "' in Ψ does not match any level in the dataset: ", levels)) + +""" + check_treatment_settings(settings::NamedTuple, levels, treatment_name) + +Checks the case/control values defining the treatment contrast are present in the dataset levels. + +Note: This method is for estimands like the ATE or IATE that have case/control treatment settings represented as +`NamedTuple`. +""" +function check_treatment_settings(settings::NamedTuple, levels, treatment_name) + for (key, val) in zip(keys(settings), settings) + any(val .== levels) || + throw(AbsentLevelError(treatment_name, key, val, levels)) + end +end + +""" + check_treatment_settings(setting, levels, treatment_name) + +Checks the value defining the treatment setting is present in the dataset levels. + +Note: This is for estimands like the CM that do not have case/control treatment settings +and are represented as simple values. +""" +function check_treatment_settings(setting, levels, treatment_name) + any(setting .== levels) || + throw( + AbsentLevelError(treatment_name, setting, levels)) +end + +function check_estimand(Ψ::Estimand, dataset) + identified, reasons = isidentified(Ψ, dataset) + identified || throw(NotIdentifiedError(reasons)) + for treatment_name in treatments(Ψ) + treatment_levels = levels(Tables.getcolumn(dataset, treatment_name)) + treatment_settings = getproperty(Ψ.treatment, treatment_name) + check_treatment_settings(treatment_settings, treatment_levels, treatment_name) + end +end + ############################################################################### ## Offset & Covariate ############################################################################### diff --git a/test/cache.jl b/test/cache.jl deleted file mode 100644 index 6c7a8c6c..00000000 --- a/test/cache.jl +++ /dev/null @@ -1,277 +0,0 @@ -module TestCache - -using Test -using TMLE -using MLJLinearModels -using MLJModels -using CategoricalArrays -using MLJBase -using StableRNGs -using Distributions - -function naive_dataset(;n=100) - rng = StableRNG(123) - W = rand(rng, Uniform(), n) - T = rand(rng, [0, 1], n) - y = 3W .+ T .+ T.*W + rand(rng, Normal(0, 0.05), n) - return ( - y = y, - W = W, - T = categorical(T) - ) -end - - -function fakefill!(cache) - n = 10 - X = (x=rand(n),) - y = rand(n) - mach = machine(LinearRegressor(), X, y) - cache.η = TMLE.NuisanceEstimands(mach, mach, mach, mach) -end - -function fakecache() - n = 10 - dataset = ( - W₁=vcat(missing, rand(n-1)), - W₂=vcat([missing, missing], rand(n-2)), - C₁=rand(n), - T₁=rand([1, 0], n), - T₂=rand([1, 0], n), - T₃=["CT", "CC", "TT", "TT", "TT", "CT", "CC", "TT", "CT", "CC"], - y=rand(n), - ynew=vcat([missing, missing, missing], rand(n-3)) - ) - cache = TMLE.TMLECache(dataset) - return cache -end - -@testset "Test check_treatment_values" begin - cache = fakecache() - Ψ = ATE( - outcome=:y, - treatment=(T₁=(case=1, control=0),), - confounders=[:W₁], - ) - TMLE.check_treatment_values(cache, Ψ) - Ψ = ATE( - outcome=:y, - treatment=( - T₁=(case=1, control=0), - T₂=(case=true, control=false),), - confounders=[:W₁], - ) - @test_throws ArgumentError(string("The 'case' string representation: 'true' for treatment", - " T₂ in Ψ does not match any level of the corresponding variable", - " in the dataset: [\"0\", \"1\"]")) TMLE.check_treatment_values(cache, Ψ) - - Ψ = CM( - outcome=:y, - treatment=(T₃="CT",), - confounders=[:W₁], - ) - TMLE.check_treatment_values(cache, Ψ) -end - -@testset "Test update!(cache, Ψ)" begin - cache = fakecache() - @test !isdefined(cache, :Ψ) - @test !isdefined(cache, :η_spec) - - Ψ = ATE( - outcome=:y, - treatment=(T₁=(case=1, control=0),), - confounders=[:W₁], - covariates=[:C₁] - ) - η_spec = NuisanceSpec( - RidgeRegressor(), - LogisticClassifier(lambda=0) - ) - TMLE.update!(cache, Ψ, η_spec) - @test cache.Ψ == Ψ - @test cache.η_spec == η_spec - @test cache.η.Q === nothing - @test cache.η.H === nothing - @test cache.η.G === nothing - @test cache.η.F === nothing - # Since no missing data is present, the columns are just propagated - # to the no_missing dataset in the sense of === - for colname in keys(cache.data[:no_missing]) - cache.data[:no_missing][colname] === cache.data[:source][colname] - end - # Pretend the cache nuisance estimands have been estimated - fakefill!(cache) - @test cache.η.H !== nothing - @test cache.η.G !== nothing - @test cache.η.Q !== nothing - @test cache.η.F !== nothing - # New treatment configuration and new confounders set - # Nuisance estimands can be reused - Ψ = ATE( - outcome=:y, - treatment=(T₁=(case=0, control=1),), - confounders=[:W₁, :W₂], - covariates=[:C₁] - ) - TMLE.update!(cache, Ψ) - @test cache.η.H !== nothing - @test cache.η.G === nothing - @test cache.η.Q === nothing - @test cache.η.F === nothing - @test cache.Ψ == Ψ - for colname in keys(cache.data[:no_missing]) - @test length(cache.data[:no_missing][colname]) == 8 - end - - # Change the outcome - # E[Y|X] must be refit - fakefill!(cache) - Ψ = ATE( - outcome=:ynew, - treatment=(T₁=(case=0, control=1),), - confounders=[:W₁, :W₂], - covariates=[:C₁] - ) - TMLE.update!(cache, Ψ) - @test cache.η.H !== nothing - @test cache.η.G !== nothing - @test cache.η.Q === nothing - @test cache.η.F === nothing - @test cache.Ψ == Ψ - for colname in keys(cache.data[:no_missing]) - @test length(cache.data[:no_missing][colname]) == 7 - end - - # Change the covariate - # E[Y|X] must be refit - fakefill!(cache) - Ψ = ATE( - outcome=:ynew, - treatment=(T₁=(case=0, control=1),), - confounders=[:W₁, :W₂], - ) - TMLE.update!(cache, Ψ) - @test cache.η.H !== nothing - @test cache.η.G !== nothing - @test cache.η.Q === nothing - @test cache.η.F === nothing - @test cache.Ψ == Ψ - - # Change only treatment setting - # only F needs refit - fakefill!(cache) - Ψ = ATE( - outcome=:ynew, - treatment=(T₁=(case=1, control=0),), - confounders=[:W₁, :W₂], - ) - TMLE.update!(cache, Ψ) - @test cache.η.H !== nothing - @test cache.η.G !== nothing - @test cache.η.Q !== nothing - @test cache.η.F === nothing - @test cache.Ψ == Ψ - - # Change the treatments - # all η must be refit - fakefill!(cache) - Ψ = ATE( - outcome=:ynew, - treatment=(T₂=(case=1, control=0),), - confounders=[:W₁, :W₂], - ) - TMLE.update!(cache, Ψ) - @test cache.η.H === nothing - @test cache.η.G === nothing - @test cache.η.Q === nothing - @test cache.η.F === nothing - @test cache.Ψ == Ψ -end - -@testset "Test update!(cache, η_spec)" begin - cache = fakecache() - Ψ = ATE( - outcome=:y, - treatment=(T₁=(case=1, control=0),), - confounders=[:W₁], - covariates=[:C₁] - ) - η_spec = NuisanceSpec( - RidgeRegressor(), - LogisticClassifier(lambda=0) - ) - TMLE.update!(cache, Ψ, η_spec) - fakefill!(cache) - # Change the Q learning stategy - η_spec = NuisanceSpec( - RidgeRegressor(lambda=10), - LogisticClassifier(lambda=0) - ) - TMLE.update!(cache, η_spec) - @test cache.η.H !== nothing - @test cache.η.G !== nothing - @test cache.η.Q === nothing - @test cache.η.F === nothing - @test cache.η_spec == η_spec - - # Change the G learning stategy - fakefill!(cache) - η_spec = NuisanceSpec( - RidgeRegressor(lambda=10), - LogisticClassifier(lambda=10) - ) - TMLE.update!(cache, η_spec) - @test cache.η.H !== nothing - @test cache.η.G === nothing - @test cache.η.Q !== nothing - @test cache.η.F === nothing - @test cache.η_spec == η_spec -end - -@testset "Test counterfactual_aggregate" begin - n=100 - dataset = naive_dataset(;n=n) - Ψ = ATE( - outcome = :y, - treatment = (T=(case=1, control=0),), - confounders = [:W] - ) - η_spec = NuisanceSpec( - MLJModels.ConstantRegressor(), - ConstantClassifier() - ) - # Nuisance estimand estimation - tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); - # The inital estimator for Q is a constant predictor, - # its prediction is the same for each counterfactual - # and the initial aggregate is zero. - # Since G is also constant the output after fluctuating - # must be constant over the counterfactual dataset - counterfactual_aggregate, counterfactual_aggregateᵢ = TMLE.counterfactual_aggregates(cache; threshold=1e-8) - @test counterfactual_aggregateᵢ == zeros(n) - @test all(first(counterfactual_aggregate) .== counterfactual_aggregate) - - # Replacing Q with a linear regression - η_spec = NuisanceSpec( - LinearRegressor(), - ConstantClassifier() - ) - tmle!(cache, η_spec, verbosity=0); - X₁ = TMLE.Qinputs(cache.η.H, (W=dataset.W, T=categorical(ones(Int, n), levels=levels(dataset.T))), Ψ) - X₀ = TMLE.Qinputs(cache.η.H, (W=dataset.W, T=categorical(zeros(Int, n), levels=levels(dataset.T))), Ψ) - ŷ₁ = TMLE.expected_value(MLJBase.predict(cache.η.Q, X₁)) - ŷ₀ = TMLE.expected_value(MLJBase.predict(cache.η.Q, X₀)) - expected_cf_agg = ŷ₁ - ŷ₀ - counterfactual_aggregate, counterfactual_aggregateᵢ = TMLE.counterfactual_aggregates(cache; threshold=1e-8) - @test counterfactual_aggregateᵢ == expected_cf_agg - # This is the coefficient in the linear regression model - var, coef = fitted_params(cache.η.Q).coefs[2] - @test var == :T__0 - @test all(coef ≈ -x for x ∈ counterfactual_aggregateᵢ) - -end - -end - -true \ No newline at end of file diff --git a/test/warm_restart.jl b/test/estimation.jl similarity index 96% rename from test/warm_restart.jl rename to test/estimation.jl index 89b833e0..2cb188b1 100644 --- a/test/warm_restart.jl +++ b/test/estimation.jl @@ -20,7 +20,7 @@ Results derived by hand for this dataset: - ATE(y₁, (T₁, T₂), W₁, W₂, C₁) = E[-4C₁ + 1 + W₂] = -0.5 - ATE(y₂, T₂, W₁, W₂, C₁) = 10 """ -function build_dataset(;n=100) +function build_dataset_and_scm(;n=100) rng = StableRNG(123) # Confounders W₁ = rand(rng, Uniform(), n) @@ -31,21 +31,30 @@ function build_dataset(;n=100) T₁ = rand(rng, Uniform(), n) .< logistic.(0.5sin.(W₁) .- 1.5W₂) T₂ = rand(rng, Uniform(), n) .< logistic.(-3W₁ - 1.5W₂) # outcome | Confounders, Covariates, Treatments - y₁ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .+ T₁ + T₂.*W₂ .+ rand(rng, Normal(0, 0.1), n) - y₂ = 1 .+ 10T₂ .+ rand(rng, Normal(0, 0.1), n) - y₃ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .- 2T₂.*T₁.*W₂ .+ rand(rng, Normal(0, 0.1), n) - return ( + Y₁ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .+ T₁ + T₂.*W₂ .+ rand(rng, Normal(0, 0.1), n) + Y₂ = 1 .+ 10T₂ .+ rand(rng, Normal(0, 0.1), n) + Y₃ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .- 2T₂.*T₁.*W₂ .+ rand(rng, Normal(0, 0.1), n) + dataset = ( T₁ = categorical(T₁), T₂ = categorical(T₂), W₁ = W₁, W₂ = W₂, C₁ = C₁, - y₁ = y₁, - y₂ = y₂, - y₃ = y₃ + Y₁ = Y₁, + Y₂ = Y₂, + Y₃ = Y₃ ) + scm = SCM( + SE(:T₁, [:W₁, :W₂]), + SE(:T₂, [:W₁, :W₂]), + SE(:Y₁, [:T₁, :T₂, :W₁, :W₂, :C₁]), + SE(:Y₂, [:T₂]), + SE(:Y₃, [:T₁, :T₂, :W₁, :W₂, :C₁]), + ) + return dataset, scm end + table_types = (Tables.columntable, DataFrame) @testset "Test Warm restart: ATE single treatment, $tt" for tt in table_types diff --git a/test/runtests.jl b/test/runtests.jl index 416ea0f4..c37ef659 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,12 +2,11 @@ using Test @time begin @test include("non_regression_test.jl") - @test include("cache.jl") @test include("utils.jl") @test include("double_robustness_ate.jl") @test include("double_robustness_iate.jl") @test include("3points_interactions.jl") - @test include("warm_restart.jl") + @test include("estimation.jl") @test include("estimands.jl") @test include("missing_management.jl") @test include("composition.jl") diff --git a/test/utils.jl b/test/utils.jl index 543c8ddf..8f1b01ea 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -12,6 +12,45 @@ using MLJGLMInterface: LinearBinaryClassifier using MLJLinearModels using MLJModels +@testset "Test estimand validation" begin + dataset = ( + W₁ = [1., 2., 3.], + T = categorical([1, 2, 3]), + Y = [1., 2., 3.] + ) + scm = SCM( + SE(:T, [:W₁, :W₂]), + SE(:Y, [:T, :W₁, :W₂]) + ) + # Not identified estimand + Ψ = ATE( + scm=scm, + outcome=:Y, + treatment=(T=(case=1, control=0),), + ) + reasons = ["Confounding variable: W₂ is not in the dataset."] + @test_throws TMLE.NotIdentifiedError(reasons) TMLE.check_estimand(Ψ, dataset) + + # Treatment values absent from dataset + scm = SCM( + SE(:T, [:W₁, :W₂]), + SE(:Y, [:T, :W₁]) + ) + Ψ = ATE( + scm=scm, + outcome=:Y, + treatment=(T=(case=1, control=0),) + ) + @test_throws TMLE.AbsentLevelError("T", "control", 0, [1, 2, 3]) TMLE.check_estimand(Ψ, dataset) + + Ψ = CM( + scm=scm, + outcome=:Y, + treatment=(T=0,), + ) + @test_throws TMLE.AbsentLevelError("T", 0, [1, 2, 3]) TMLE.check_estimand(Ψ, dataset) +end + @testset "Test indicator_fns & indicator_values" begin scm = StaticConfoundedModel( [:Y], @@ -138,7 +177,7 @@ end ) Ψ = ATE( scm=scm, - outcome =:Y, + outcome=:Y, treatment=(T=(case="a", control="b"),), ) dataset = ( From b7f8f6fb908e69ec441c221f12c40547d2a58d1e Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 6 Aug 2023 17:16:39 +0100 Subject: [PATCH 020/151] update more tests --- src/scm.jl | 11 +- test/estimation.jl | 386 +++++++++++++-------------------------------- 2 files changed, 120 insertions(+), 277 deletions(-) diff --git a/src/scm.jl b/src/scm.jl index 958e0f02..ded196ce 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -1,10 +1,6 @@ ##################################################################### ### Structural Equation ### ##################################################################### - -SelfReferringEquationError(outcome) = - ArgumentError(string("Variable ", outcome, " appears on both sides of the equation.")) - mutable struct StructuralEquation outcome::Symbol parents::Vector{Symbol} @@ -20,7 +16,14 @@ StructuralEquation(outcome, parents; model=nothing) = StructuralEquation(outcome const SE = StructuralEquation +SelfReferringEquationError(outcome) = + ArgumentError(string("Variable ", outcome, " appears on both sides of the equation.")) + +NoModelError(eq::SE) = ArgumentError(string("It seems the following structural equation needs to be fitted.\n", + " Please provide a suitable model for it :\n\t⋆ ", eq)) + function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) + eq.model !== nothing || throw(NoModelError(eq)) # Fit if never fitted or if new model if eq.mach === nothing || eq.model != eq.mach.model verbosity >= 1 && @info(string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".")) diff --git a/test/estimation.jl b/test/estimation.jl index 2cb188b1..d933c3e2 100644 --- a/test/estimation.jl +++ b/test/estimation.jl @@ -32,7 +32,7 @@ function build_dataset_and_scm(;n=100) T₂ = rand(rng, Uniform(), n) .< logistic.(-3W₁ - 1.5W₂) # outcome | Confounders, Covariates, Treatments Y₁ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .+ T₁ + T₂.*W₂ .+ rand(rng, Normal(0, 0.1), n) - Y₂ = 1 .+ 10T₂ .+ rand(rng, Normal(0, 0.1), n) + Y₂ = 1 .+ 10T₂ .+ W₂ .+ rand(rng, Normal(0, 0.1), n) Y₃ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .- 2T₂.*T₁.*W₂ .+ rand(rng, Normal(0, 0.1), n) dataset = ( T₁ = categorical(T₁), @@ -45,11 +45,11 @@ function build_dataset_and_scm(;n=100) Y₃ = Y₃ ) scm = SCM( - SE(:T₁, [:W₁, :W₂]), - SE(:T₂, [:W₁, :W₂]), - SE(:Y₁, [:T₁, :T₂, :W₁, :W₂, :C₁]), - SE(:Y₂, [:T₂]), - SE(:Y₃, [:T₁, :T₂, :W₁, :W₂, :C₁]), + SE(:T₁, [:W₁, :W₂], LogisticClassifier(lambda=0)), + SE(:T₂, [:W₁, :W₂], LogisticClassifier(lambda=0)), + SE(:Y₁, [:T₁, :T₂, :W₁, :W₂, :C₁], TreatmentTransformer() |> LinearRegressor()), + SE(:Y₂, [:T₂, :W₂], TreatmentTransformer() |> LinearRegressor()), + SE(:Y₃, [:T₁, :T₂, :W₁, :W₂, :C₁], TreatmentTransformer() |> LinearRegressor()), ) return dataset, scm end @@ -58,396 +58,236 @@ end table_types = (Tables.columntable, DataFrame) @testset "Test Warm restart: ATE single treatment, $tt" for tt in table_types - dataset = tt(build_dataset(;n=50_000)) + dataset, scm = build_dataset_and_scm(;n=50_000) + dataset = tt(dataset) # Define the estimand of interest Ψ = ATE( - outcome=:y₁, + scm=scm, + outcome=:Y₁, treatment=(T₁=(case=true, control=false),), - confounders=[:W₁, :W₂], - covariates=[:C₁] - ) - # Define the nuisance estimands specification - η_spec = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0) ) # Run TMLE log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Fitting P(T|W)"), - (:info, "→ Fitting Encoder"), - (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), + (:info, "Fitting the required equations..."), + (:info, "Fitting Structural Equation corresponding to variable Y₁."), + (:info, "Fitting Structural Equation corresponding to variable T₁."), + (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, cache = @test_logs log_sequence... tmle(Ψ, η_spec, dataset; verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset; verbosity=1); # The TMLE covers the ground truth but the initial estimate does not Ψ₀ = -1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Update the treatment specification # Nuisance estimands should not fitted again Ψ = ATE( - outcome=:y₁, + scm=scm, + outcome=:Y₁, treatment=(T₁=(case=false, control=true),), - confounders=[:W₁, :W₂], - covariates=[:C₁] ) log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Reusing previous P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), + (:info, "Fitting the required equations..."), + (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); Ψ₀ = 1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Remove the covariate variable, this will trigger the refit of Q - Ψ = ATE( - outcome=:y₁, - treatment=(T₁=(case=true, control=false),), - confounders=[:W₁, :W₂], - ) - log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Reusing previous P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), - (:info, "Done.") - ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1, weighted_fluctuation=true); - Ψ₀ = -1 - @test tmle_result.estimand == Ψ - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - # Change the treatment - # This will trigger the refit of all η + # New marginal treatment estimation: T₂ + # This will trigger fit of the corresponding equation Ψ = ATE( - outcome=:y₁, + scm=scm, + outcome=:Y₁, treatment=(T₂=(case=true, control=false),), - confounders=[:W₁, :W₂], ) log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Fitting P(T|W)"), - (:info, "→ Fitting Encoder"), - (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), + (:info, "Fitting the required equations..."), + (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); Ψ₀ = 0.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - # Remove a confounding variable - # This will trigger the refit of Q and G - # Since we are not accounting for all confounders - # we can't have ground truth coverage on this setting + # Change the outcome: Y₂ + # This will trigger the fit for the corresponding equation Ψ = ATE( - outcome=:y₁, + scm=scm, + outcome=:Y₂, treatment=(T₂=(case=true, control=false),), - confounders=[:W₁], ) log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Fitting P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), + (:info, "Fitting the required equations..."), + (:info, "Fitting Structural Equation corresponding to variable Y₂."), + (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1, weighted_fluctuation=true); - @test tmle_result.estimand == Ψ - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Change the outcome - # This will trigger the refit of Q only - Ψ = ATE( - outcome=:y₂, - treatment=(T₂=(case=true, control=false),), - confounders=[:W₁], - ) - log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Reusing previous P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), - (:info, "Done.") - ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); Ψ₀ = 10 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₂) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - # Adding a covariate - # This will trigger the refit of Q only + # At this point all equations have been fitted + # New estimand with 2 treatment variables Ψ = ATE( - outcome=:y₂, - treatment=(T₂=(case=true, control=false),), - confounders=[:W₁], - covariates=[:C₁] - ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); - @test tmle_result.estimand == Ψ - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₂) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) -end - -@testset "Test Warm restart: ATE multiple treatment, $tt" for tt in table_types - dataset = tt(build_dataset(;n=50_000)) - # Define the estimand of interest - Ψ = ATE( - outcome=:y₁, + scm=scm, + outcome=:Y₁, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders=[:W₁, :W₂], - covariates=[:C₁] ) - # Define the nuisance estimands specification - η_spec = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0) + log_sequence = ( + (:info, "Fitting the required equations..."), + (:info, "Performing TMLE..."), + (:info, "Done.") ) - tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); Ψ₀ = -0.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - # Let's switch case and control for T₂ + # Switching the T₂ setting Ψ = ATE( - outcome=:y₁, + scm=scm, + outcome=:Y₁, treatment=(T₁=(case=true, control=false), T₂=(case=false, control=true)), - confounders=[:W₁, :W₂], - covariates=[:C₁] ) - log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Reusing previous P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), - (:info, "Done.") - ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); Ψ₀ = -1.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @testset "Test Warm restart: CM, $tt" for tt in table_types - dataset = tt(build_dataset(;n=50_000)) + dataset, scm = build_dataset_and_scm(;n=50_000) + dataset = tt(dataset) + Ψ = CM( - outcome=:y₁, + scm=scm, + outcome=:Y₁, treatment=(T₁=true, T₂=true), - confounders=[:W₁, :W₂], - covariates=[:C₁] ) - η_spec = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0) + log_sequence = ( + (:info, "Fitting the required equations..."), + (:info, "Fitting Structural Equation corresponding to variable Y₁."), + (:info, "Fitting Structural Equation corresponding to variable T₁."), + (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Performing TMLE..."), + (:info, "Done.") ) - tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); Ψ₀ = 3 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Let's switch case and control for T₂ Ψ = CM( - outcome=:y₁, + scm=scm, + outcome=:Y₁, treatment=(T₁=true, T₂=false), - confounders=[:W₁, :W₂], - covariates=[:C₁] ) log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Reusing previous P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), + (:info, "Fitting the required equations..."), + (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); Ψ₀ = 2.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Change the outcome Ψ = CM( - outcome=:y₂, - treatment=(T₁=true, T₂=false), - confounders=[:W₁, :W₂], - covariates=[:C₁] + scm=scm, + outcome=:Y₂, + treatment=(T₂=false, ), ) log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Reusing previous P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), + (:info, "Fitting the required equations..."), + (:info, "Fitting Structural Equation corresponding to variable Y₂."), + (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); - Ψ₀ = 1 + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + Ψ₀ = 1.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₂) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - # Change the treatment - Ψ = CM( - outcome=:y₂, - treatment=(T₂=true,), - confounders=[:W₁, :W₂], - covariates=[:C₁] - ) + # Change the treatment model + scm.T₂.model = LogisticClassifier(lambda=0.01) log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Fitting P(T|W)"), - (:info, "→ Fitting Encoder"), - (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), + (:info, "Fitting the required equations..."), + (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); - Ψ₀ = 11 + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₂) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @testset "Test Warm restart: pairwise IATE, $tt" for tt in table_types - dataset = tt(build_dataset(;n=50_000)) + dataset, scm = build_dataset_and_scm(;n=50_000) + dataset = tt(dataset) Ψ = IATE( - outcome=:y₃, + scm=scm, + outcome=:Y₃, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders=[:W₁, :W₂], - covariates=[:C₁] - ) - η_spec = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0) - ) - tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); - Ψ₀ = -1 - @test tmle_result.estimand == Ψ - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₃) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Remove covariate from fit - Ψ = IATE( - outcome=:y₃, - treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders=[:W₁, :W₂], - ) - log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Reusing previous P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Fitting E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), - (:info, "Done.") - ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1, weighted_fluctuation=true); - @test tmle_result.estimand == Ψ - test_coverage(tmle_result, Ψ₀) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Changing the treatments values - Ψ = IATE( - outcome=:y₃, - treatment=(T₁=(case=false, control=true), T₂=(case=true, control=false)), - confounders=[:W₁, :W₂], ) log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Reusing previous P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), + (:info, "Fitting the required equations..."), + (:info, "Fitting Structural Equation corresponding to variable Y₃."), + (:info, "Fitting Structural Equation corresponding to variable T₁."), + (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψ, verbosity=1); - Ψ₀ = - Ψ₀ + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + Ψ₀ = -1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₃) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) -end + # Changing some equations + scm.Y₃.model = TreatmentTransformer() |> RidgeRegressor(lambda=1e-5) + scm.T₂.model = LogisticClassifier(lambda=1e-5) -@testset "Test Warm restart: Both Ψ and η changed" begin - dataset = build_dataset(;n=50_000) - # Define the estimand of interest - Ψ = ATE( - outcome=:y₁, - treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders=[:W₁, :W₂], - covariates=[:C₁] - ) - # Define the nuisance estimands specification - η_spec = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0) - ) - tmle_result, cache = tmle(Ψ, η_spec, dataset; verbosity=0); - Ψ₀ = -0.5 - @test tmle_result.estimand == Ψ - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - Ψnew = ATE( - outcome=:y₁, - treatment=(T₁=(case=false, control=true), T₂=(case=false, control=true)), - confounders=[:W₁, :W₂], - covariates=[:C₁] - ) - # Define the nuisance estimands specification - η_spec_new = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0.1) - ) log_sequence = ( - (:info, "Fitting the nuisance estimands..."), - (:info, "→ Reusing previous P(T|W)"), - (:info, "→ Reusing previous Encoder"), - (:info, "→ Reusing previous E[Y|X]"), - (:info, "Targeting the nuisance estimands..."), + (:info, "Fitting the required equations..."), + (:info, "Fitting Structural Equation corresponding to variable Y₃."), + (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, cache = @test_logs log_sequence... tmle!(cache, Ψnew, η_spec; verbosity=1); - Ψ₀ = 0.5 - @test tmle_result.estimand == Ψnew + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(cache; outcome_name=:y₁) + test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end From 9701cd7b10a2cfc61cb2ac1a87e534cb392a9e33 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 7 Aug 2023 10:22:37 +0100 Subject: [PATCH 021/151] tests working locally --- Project.toml | 4 +--- src/TMLE.jl | 2 -- src/estimands.jl | 17 ++++++++++++---- src/scm.jl | 17 ++++++++++------ test/configuration.jl | 38 ----------------------------------- test/double_robustness_ate.jl | 2 +- test/runtests.jl | 2 +- test/scm.jl | 15 +++----------- 8 files changed, 30 insertions(+), 67 deletions(-) delete mode 100644 test/configuration.jl diff --git a/Project.toml b/Project.toml index 97e19180..bfed3c91 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.11.4" [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" -Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" @@ -27,7 +26,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractDifferentiation = "0.4, 0.5" CategoricalArrays = "0.10" -Configurations = "0.17" Distributions = "0.25" GLM = "1.8.2" HypothesisTests = "0.10" @@ -36,10 +34,10 @@ MLJBase = "0.19, 0.20, 0.21" MLJGLMInterface = "0.3.4" MLJModels = "0.15, 0.16" Missings = "1.0" +PrecompileTools = "1.1.1" PrettyTables = "2.2" TableOperations = "1.2" Tables = "1.6" YAML = "0.4" Zygote = "0.6" -PrecompileTools = "1.1.1" julia = "1.6, 1.7, 1" diff --git a/src/TMLE.jl b/src/TMLE.jl index 5d7df5c1..283602a9 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -14,7 +14,6 @@ using Distributions using Zygote using LogExpFunctions using YAML -using Configurations using PrecompileTools using PrettyTables using Random @@ -44,7 +43,6 @@ include("estimands.jl") include("utils.jl") include("estimation.jl") include("estimate.jl") -include("configuration.jl") # ############################################################################# # PRECOMPILATION WORKLOAD diff --git a/src/estimands.jl b/src/estimands.jl index dc2071dc..3dfd96ab 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -63,7 +63,7 @@ CM₂ = CM( ) ``` """ -@option struct ConditionalMean <: Estimand +struct ConditionalMean <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple @@ -71,6 +71,9 @@ end const CM = ConditionalMean +CM(;scm, outcome, treatment) = CM(scm, outcome, treatment) +CM(scm; outcome, treatment) = CM(scm, outcome, treatment) + name(::Type{CM}) = "CM" ##################################################################### @@ -110,7 +113,7 @@ ATE₂ = ATE( ) ``` """ -@option struct AverageTreatmentEffect <: Estimand +struct AverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple @@ -118,6 +121,9 @@ end const ATE = AverageTreatmentEffect +ATE(;scm, outcome, treatment) = ATE(scm, outcome, treatment) +ATE(scm; outcome, treatment) = ATE(scm, outcome, treatment) + name(::Type{ATE}) = "ATE" ##################################################################### @@ -150,7 +156,7 @@ IATE₁ = IATE( ) ``` """ -@option struct InteractionAverageTreatmentEffect <: Estimand +struct InteractionAverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple @@ -158,6 +164,9 @@ end const IATE = InteractionAverageTreatmentEffect +IATE(;scm, outcome, treatment) = IATE(scm, outcome, treatment) +IATE(scm; outcome, treatment) = IATE(scm, outcome, treatment) + name(::Type{IATE}) = "IATE" ##################################################################### @@ -166,7 +175,7 @@ name(::Type{IATE}) = "IATE" CMCompositeEstimand = Union{CM, ATE, IATE} -function Base.show(io::IO, Ψ::T) where T <: CMCompositeEstimand +function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand param_string = string( name(T), "\n-----", diff --git a/src/scm.jl b/src/scm.jl index ded196ce..8a52747a 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -21,12 +21,14 @@ SelfReferringEquationError(outcome) = NoModelError(eq::SE) = ArgumentError(string("It seems the following structural equation needs to be fitted.\n", " Please provide a suitable model for it :\n\t⋆ ", eq)) - + +fit_message(eq::SE) = string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".") + function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) eq.model !== nothing || throw(NoModelError(eq)) # Fit if never fitted or if new model if eq.mach === nothing || eq.model != eq.mach.model - verbosity >= 1 && @info(string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".")) + verbosity >= 1 && @info(fit_message(eq)) data = nomissing(dataset, vcat(parents(eq), outcome(eq))) X = selectcols(data, parents(eq)) y = Tables.getcolumn(data, outcome(eq)) @@ -35,6 +37,7 @@ function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) eq.mach = mach # Otherwise only fit if force is true else + verbosity >= 1 && force === true && @info(fit_message(eq)) MLJBase.fit!(eq.mach, verbosity=verbosity-1, force=force) end end @@ -49,7 +52,7 @@ function string_repr(eq::SE; subscript="") return eq_string end -Base.show(io::IO, eq::SE) = println(io, string_repr(eq)) +Base.show(io::IO, ::MIME"text/plain", eq::SE) = println(io, string_repr(eq)) assign_model!(eq::SE, model::Nothing) = nothing assign_model!(eq::SE, model::Model) = eq.model = model @@ -69,8 +72,10 @@ end const SCM = StructuralCausalModel -StructuralCausalModel(equations::Vararg{SE}) = - StructuralCausalModel(Dict(outcome(eq) => eq for eq in equations)) +SCM() = SCM(Dict{Symbol, SE}()) + +SCM(equations::Vararg{SE}) = + SCM(Dict(outcome(eq) => eq for eq in equations)) equations(scm::SCM) = scm.equations @@ -88,7 +93,7 @@ function string_repr(scm::SCM) return scm_string end -Base.show(io::IO, scm::SCM) = println(io, string_repr(scm)) +Base.show(io::IO, ::MIME"text/plain", scm::SCM) = println(io, string_repr(scm)) function Base.push!(scm::SCM, eq::SE) diff --git a/test/configuration.jl b/test/configuration.jl deleted file mode 100644 index 950e16c3..00000000 --- a/test/configuration.jl +++ /dev/null @@ -1,38 +0,0 @@ -module TestConfigurations - -using Test -using TMLE -using Serialization - -@testset "Test configurations" begin - param_file = "estimands_sample.yaml" - estimands = [ - IATE(; - outcome=:Y1, - treatment=(T2 = (case = 1, control = 0), T1 = (case = 1, control = 0)), - confounders=[:W1, :W2], - covariates=Symbol[:C1] - ), - ATE(; - outcome=:Y3, - treatment=(T3 = (case = 1, control = 0), T1 = (case = "AC", control = "CC")), - confounders=[:W1], - covariates=Symbol[] - ), - CM(;outcome=:Y3, treatment=(T3 = "AC", T1 = "CC"), confounders=[:W1], covariates=Symbol[]) - ] - # Test YAML Serialization - estimands_to_yaml(param_file, estimands) - new_estimands = estimands_from_yaml(param_file) - @test new_estimands == estimands - rm(param_file) - # Test basic Serialization - serialize(param_file, estimands) - new_estimands = deserialize(param_file) - @test new_estimands == estimands - rm(param_file) -end - -end - -true \ No newline at end of file diff --git a/test/double_robustness_ate.jl b/test/double_robustness_ate.jl index 73cf478d..5b4865a5 100644 --- a/test/double_robustness_ate.jl +++ b/test/double_robustness_ate.jl @@ -235,7 +235,7 @@ end outcome = :Y, treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=0.)), ) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=1) + tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) diff --git a/test/runtests.jl b/test/runtests.jl index c37ef659..58a0cc54 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test @time begin + @test include("scm.jl") @test include("non_regression_test.jl") @test include("utils.jl") @test include("double_robustness_ate.jl") @@ -11,5 +12,4 @@ using Test @test include("missing_management.jl") @test include("composition.jl") @test include("treatment_transformer.jl") - @test include("configuration.jl") end \ No newline at end of file diff --git a/test/scm.jl b/test/scm.jl index e72b6a98..6a6e710f 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -113,19 +113,10 @@ end @test isdefined(eq.mach, :data) end # Refit will not do anything - nofit_log_sequence = ( - (:info, "Structural Equation corresponding to variable Ycont already fitted, skipping. Set `force=true` to force refit."), - (:info, "Structural Equation corresponding to variable Ycat already fitted, skipping. Set `force=true` to force refit."), - (:info, "Structural Equation corresponding to variable T₁ already fitted, skipping. Set `force=true` to force refit."), - (:info, "Structural Equation corresponding to variable T₂ already fitted, skipping. Set `force=true` to force refit."), - ) + nofit_log_sequence = () @test_logs nofit_log_sequence... fit!(scm, dataset, verbosity = 1) - # Force refit and set cache to false - @test_logs fit_log_sequence... fit!(scm, dataset, verbosity = 1, force=true, cache=false) - for (key, eq) in equations(scm) - @test eq.mach isa Machine - @test !isdefined(eq.mach, :data) - end + # Force refit + @test_logs fit_log_sequence... fit!(scm, dataset, verbosity = 1, force=true) # Reset scm reset!(scm) From a6980ccc5deb79191edd20c6ef461df1a35b2667 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 7 Aug 2023 10:42:30 +0100 Subject: [PATCH 022/151] add more estimand checks --- src/estimands.jl | 24 ++++++++++++++++++++++++ test/estimands.jl | 17 +++++++++++++++++ test/utils.jl | 6 +++--- 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/src/estimands.jl b/src/estimands.jl index 3dfd96ab..4fb7d942 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -67,6 +67,10 @@ struct ConditionalMean <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple + function ConditionalMean(scm, outcome, treatment) + check_parameter_against_scm(scm, outcome, treatment) + return new(scm, outcome, treatment) + end end const CM = ConditionalMean @@ -117,6 +121,10 @@ struct AverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple + function AverageTreatmentEffect(scm, outcome, treatment) + check_parameter_against_scm(scm, outcome, treatment) + return new(scm, outcome, treatment) + end end const ATE = AverageTreatmentEffect @@ -160,6 +168,10 @@ struct InteractionAverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple + function InteractionAverageTreatmentEffect(scm, outcome, treatment) + check_parameter_against_scm(scm, outcome, treatment) + return new(scm, outcome, treatment) + end end const IATE = InteractionAverageTreatmentEffect @@ -175,6 +187,18 @@ name(::Type{IATE}) = "IATE" CMCompositeEstimand = Union{CM, ATE, IATE} +VariableNotAChildInSCMError(variable) = ArgumentError(string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")) +TreatmentMustBeInOutcomeParentsError(variable) = ArgumentError(string("Treatment variable ", variable, " must be a direct parent of the outcome.")) + +function check_parameter_against_scm(scm::SCM, outcome, treatment) + eqs = equations(scm) + haskey(eqs, outcome) || throw(VariableNotAChildInSCMError(outcome)) + for treatment_variable in keys(treatment) + haskey(eqs, treatment_variable) || throw(VariableNotAChildInSCMError(treatment_variable)) + treatment_variable ∈ parents(eqs[outcome]) || throw(TreatmentMustBeInOutcomeParentsError(treatment_variable)) + end +end + function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand param_string = string( name(T), diff --git a/test/estimands.jl b/test/estimands.jl index 2c0a287d..d6b27b20 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -21,6 +21,23 @@ using MLJGLMInterface :Y, :T, [:W₁, :W₂], covariates=:C₁ ) + # Parameter validation + @test_throws TMLE.VariableNotAChildInSCMError("UndefinedOutcome") CM( + scm, + treatment=(T=1,), + outcome=:UndefinedOutcome + ) + @test_throws TMLE.VariableNotAChildInSCMError("UndefinedTreatment") CM( + scm, + treatment=(UndefinedTreatment=1,), + outcome=:Y + ) + @test_throws TMLE.TreatmentMustBeInOutcomeParentsError("Y") CM( + scm, + treatment=(Y=1,), + outcome=:T + ) + Ψ = CM( scm=scm, treatment=(T=1,), diff --git a/test/utils.jl b/test/utils.jl index 8f1b01ea..33d91e81 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -77,7 +77,7 @@ end # ATE Ψ = ATE( scm=scm, - outcome=:y, + outcome=:Y, treatment=(T₁=(case="A", control="B"), T₂=(control=0, case=1)), ) indicator_fns = TMLE.indicator_fns(Ψ) @@ -90,7 +90,7 @@ end # 2-points IATE Ψ = IATE( scm=scm, - outcome=:y, + outcome=:Y, treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0)), ) indicator_fns = TMLE.indicator_fns(Ψ) @@ -105,7 +105,7 @@ end # 3-points IATE Ψ = IATE( scm=scm, - outcome=:y, + outcome=:Y, treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0), T₃=(control="D", case="C")), ) indicator_fns = TMLE.indicator_fns(Ψ) From 5d0e9a2e379141c40e486f3a070805ae465e67df Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 7 Aug 2023 10:48:18 +0100 Subject: [PATCH 023/151] more constructors --- src/scm.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/scm.jl b/src/scm.jl index 8a52747a..8db1193c 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -12,10 +12,11 @@ mutable struct StructuralEquation end end -StructuralEquation(outcome, parents; model=nothing) = StructuralEquation(outcome, parents, model) - const SE = StructuralEquation +SE(outcome, parents; model=nothing) = SE(outcome, parents, model) +SE(;outcome, parents, model=nothing) = SE(outcome, parents, model) + SelfReferringEquationError(outcome) = ArgumentError(string("Variable ", outcome, " appears on both sides of the equation.")) @@ -72,8 +73,7 @@ end const SCM = StructuralCausalModel -SCM() = SCM(Dict{Symbol, SE}()) - +SCM(;equations=Dict{Symbol, SE}()) = SCM(equations) SCM(equations::Vararg{SE}) = SCM(Dict(outcome(eq) => eq for eq in equations)) From a2064e95327e001d629d17d5b9555b25cd3128c4 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 7 Aug 2023 10:52:44 +0100 Subject: [PATCH 024/151] update export list --- src/TMLE.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index 283602a9..3911758b 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -26,11 +26,11 @@ import AbstractDifferentiation as AD export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel export CM, ATE, IATE -export parents, fit!, reset!, isidentified, equations -export tmle, tmle! +export parents, fit!, reset!, isidentified, equations, \ + optimize_ordering, optimize_ordering! +export tmle export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose -export estimands_from_yaml, estimands_to_yaml, optimize_ordering, optimize_ordering! export TreatmentTransformer # ############################################################################# From f801af293edf70da3f71747776c62e6de1ae25e7 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 7 Aug 2023 11:28:01 +0100 Subject: [PATCH 025/151] add upstream methods --- src/estimands.jl | 9 ++------- src/scm.jl | 30 +++++++++++++++++++++++++++++- test/scm.jl | 17 +++++++++++++++++ 3 files changed, 48 insertions(+), 8 deletions(-) diff --git a/src/estimands.jl b/src/estimands.jl index 4fb7d942..6f588bd3 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -188,14 +188,14 @@ name(::Type{IATE}) = "IATE" CMCompositeEstimand = Union{CM, ATE, IATE} VariableNotAChildInSCMError(variable) = ArgumentError(string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")) -TreatmentMustBeInOutcomeParentsError(variable) = ArgumentError(string("Treatment variable ", variable, " must be a direct parent of the outcome.")) +TreatmentMustBeInOutcomeParentsError(variable) = ArgumentError(string("Treatment variable ", variable, " must be a parent of the outcome.")) function check_parameter_against_scm(scm::SCM, outcome, treatment) eqs = equations(scm) haskey(eqs, outcome) || throw(VariableNotAChildInSCMError(outcome)) for treatment_variable in keys(treatment) haskey(eqs, treatment_variable) || throw(VariableNotAChildInSCMError(treatment_variable)) - treatment_variable ∈ parents(eqs[outcome]) || throw(TreatmentMustBeInOutcomeParentsError(treatment_variable)) + is_upstream(treatment_variable, outcome, scm) || throw(TreatmentMustBeInOutcomeParentsError(treatment_variable)) end end @@ -281,11 +281,6 @@ F_model(::Type{<:AbstractVector{<:Finite}}) = F_model(t::Type{Any}) = throw(ArgumentError("Cannot proceed with Q model with target_scitype $t")) -namedtuples_from_dicts(d) = d -namedtuples_from_dicts(d::Dict) = - NamedTuple{Tuple(keys(d))}([namedtuples_from_dicts(val) for val in values(d)]) - - function param_key(Ψ::CMCompositeEstimand) return ( join(values(confounders(Ψ))..., "_"), diff --git a/src/scm.jl b/src/scm.jl index 8db1193c..74b078c6 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -115,7 +115,8 @@ function Base.getproperty(scm::StructuralCausalModel, key::Symbol) return scm.equations[key] end -parents(scm::StructuralCausalModel, key::Symbol) = parents(getproperty(scm, key)) +parents(scm::StructuralCausalModel, key::Symbol) = + haskey(equations(scm), key) ? parents(scm[key]) : Symbol[] function MLJBase.fit!(scm::SCM, dataset; verbosity=1, cache=true, force=false) for eq in values(equations(scm)) @@ -128,6 +129,33 @@ function reset!(scm::SCM) reset!(eq) end end + +""" + is_upstream(var₁, var₂, scm::SCM) + +Checks whether var₁ is upstream of var₂ in the SCM. +""" +function is_upstream(var₁, var₂, scm::SCM) + upstream = parents(scm, var₂) + while true + if var₁ ∈ upstream + return true + else + upstream = vcat((parents(scm, v) for v in upstream)...) + end + isempty(upstream) && return false + end +end + +function upstream_variables(var, scm::SCM) + upstream = parents(scm, var) + current_upstream = upstream + while true + current_upstream = vcat((parents(scm, v) for v in current_upstream)...) + union!(upstream, current_upstream) + isempty(current_upstream) && return upstream + end +end ##################################################################### ### StaticConfoundedModel ### ##################################################################### diff --git a/test/scm.jl b/test/scm.jl index 6a6e710f..475910dc 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -125,5 +125,22 @@ end end end +@testset "Test is_upstream && upstream_variables" begin + scm = SCM( + SE(:T₁, [:W₁]), + SE(:T₂, [:W₂]), + SE(:Y, [:T₁, :T₂]), + ) + parents(scm, :Y) + @test TMLE.is_upstream(:T₁, :Y, scm) === true + @test TMLE.is_upstream(:W₁, :Y, scm) === true + @test TMLE.is_upstream(:Y, :T₁, scm) === false + + @test TMLE.upstream_variables(:Y, scm) == [:T₁, :T₂, :W₁, :W₂] + @test TMLE.upstream_variables(:T₁, scm) == [:W₁] + @test TMLE.upstream_variables(:W₁, scm) == Symbol[] + +end + end true \ No newline at end of file From eb180f7752e5027c39fa4d29d050a2bcc3912aff Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 7 Aug 2023 18:52:17 +0100 Subject: [PATCH 026/151] add more generality for accepted graph --- src/TMLE.jl | 6 +- src/adjustment.jl | 37 +++++++++ src/configuration.jl | 36 --------- src/estimands.jl | 179 +------------------------------------------ src/estimation.jl | 41 +++++++--- src/scm.jl | 42 ++++++++-- src/utils.jl | 11 ++- test/adjustment.jl | 28 +++++++ test/estimands.jl | 40 +++------- test/estimation.jl | 18 +++-- test/utils.jl | 37 +++------ 11 files changed, 178 insertions(+), 297 deletions(-) create mode 100644 src/adjustment.jl delete mode 100644 src/configuration.jl create mode 100644 test/adjustment.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index 3911758b..7a089d05 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -26,12 +26,12 @@ import AbstractDifferentiation as AD export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel export CM, ATE, IATE -export parents, fit!, reset!, isidentified, equations, \ - optimize_ordering, optimize_ordering! +export parents, fit!, reset!, isidentified, equations, optimize_ordering, optimize_ordering! export tmle export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer +export BackdoorAdjustment # ############################################################################# # INCLUDES @@ -40,10 +40,12 @@ export TreatmentTransformer include("scm.jl") include("treatment_transformer.jl") include("estimands.jl") +include("adjustment.jl") include("utils.jl") include("estimation.jl") include("estimate.jl") + # ############################################################################# # PRECOMPILATION WORKLOAD # ############################################################################# diff --git a/src/adjustment.jl b/src/adjustment.jl new file mode 100644 index 00000000..71572c4f --- /dev/null +++ b/src/adjustment.jl @@ -0,0 +1,37 @@ + +abstract type AdjustmentMethod end + +struct BackdoorAdjustment <: AdjustmentMethod + outcome_extra::Vector{Symbol} +end + +BackdoorAdjustment(;outcome_extra=[]) = BackdoorAdjustment(outcome_extra) + +""" + outcome_input_variables(treatments, confounding_variables) + +This is defined as the set of variable corresponding to: + +- Treatment variables +- Confounding variables +- Extra covariates +""" +outcome_input_variables(treatments, confounding_variables, extra_covariates) = unique(vcat( + treatments, + confounding_variables, + extra_covariates +)) + +function get_models_input_variables(adjustment_method::BackdoorAdjustment, Ψ::CMCompositeEstimand) + models_inputs = [] + for treatment in treatments(Ψ) + push!(models_inputs, parents(Ψ.scm, treatment)) + end + unique_confounders = unique(vcat(models_inputs...)) + push!( + models_inputs, + outcome_input_variables(treatments(Ψ), unique_confounders, adjustment_method.outcome_extra) + ) + variables = Tuple(vcat(treatments(Ψ), outcome(Ψ))) + return NamedTuple{variables}(models_inputs) +end diff --git a/src/configuration.jl b/src/configuration.jl deleted file mode 100644 index 38ec485a..00000000 --- a/src/configuration.jl +++ /dev/null @@ -1,36 +0,0 @@ -Configurations.from_dict(::Type{<:Estimand}, ::Type{Symbol}, s) = Symbol(s) -Configurations.from_dict(::Type{<:Estimand}, ::Type{NamedTuple}, s) = eval(Meta.parse(s)) - -function param_to_dict(Ψ::Estimand) - d = to_dict(Ψ, YAMLStyle) - d["type"] = typeof(Ψ) - return d -end - -""" - estimands_from_yaml(filepath) - -Read estimation `Estimands` from YAML file. -""" -function estimands_from_yaml(filepath) - yaml_dict = YAML.load_file(filepath; dicttype=Dict{String, Any}) - params_dicts = yaml_dict["Estimands"] - nparams = size(params_dicts, 1) - estimands = Vector{Estimand}(undef, nparams) - for index in 1:nparams - d = params_dicts[index] - type = pop!(d, "type") - estimands[index] = from_dict(eval(Meta.parse(type)), d) - end - return estimands -end - -""" - estimands_to_yaml(filepath, estimands::Vector{<:Estimand}) - -Write estimation `Estimands` to YAML file. -""" -function estimands_to_yaml(filepath, estimands) - d = Dict("Estimands" => [param_to_dict(Ψ) for Ψ in estimands]) - YAML.write_file(filepath, d) -end \ No newline at end of file diff --git a/src/estimands.jl b/src/estimands.jl index 6f588bd3..76a7d05c 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -1,17 +1,3 @@ -const causal_graph = """ - T ← W - ↘ ↙ - Y ← C - -## Notation: - -- Y: outcome -- T: treatment -- W: confounders -- C: covariates -- X = (W, C, T) -""" - ##################################################################### ### Abstract Estimand ### ##################################################################### @@ -20,49 +6,11 @@ A Estimand is a functional on distribution space Ψ: ℳ → ℜ. """ abstract type Estimand end -function MLJBase.fit!(Ψ::Estimand, dataset; verbosity=1, cache=true, force=false) - for eq in equations_to_fit(Ψ) - fit!(eq, dataset, verbosity=verbosity, cache=cache, force=force) - end -end ##################################################################### ### Conditional Mean ### ##################################################################### -""" -# CM: Conditional Mean - -Mathematical definition: - - Eₓ[E[Y|do(T=t), X]] - -# Assumed Causal graph: - -$causal_graph - -# Fields: - - outcome: A symbol identifying the outcome variable of interest - - treatment: A NamedTuple linking each treatment variable to a value - - confounders: Confounding variables affecting both the outcome and the treatment - - covariates: Optional extra variables affecting the outcome only - -# Examples: -```julia -CM₁ = CM( - outcome=:Y₁, - treatment=(T₁=1,), - confounders=[:W₁, :W₂], - covariates=[:C₁] -) - -CM₂ = CM( - outcome=:Y₂, - treatment=(T₁=1, T₂="A"), - confounders=[:W₁], -) -``` -""" struct ConditionalMean <: Estimand scm::StructuralCausalModel outcome::Symbol @@ -84,39 +32,6 @@ name(::Type{CM}) = "CM" ### Average Treatment Effect ### ##################################################################### -""" -# ATE: Average Treatment Effect - -Mathematical definition: - - Eₓ[E[Y|do(T=case), X]] - Eₓ[E[Y|do(T=control), X]] - -# Assumed Causal graph: - -$causal_graph - -# Fields: - - outcome: A symbol identifying the outcome variable of interest - - treatment: A NamedTuple linking each treatment variable to case/control values - - confounders: Confounding variables affecting both the outcome and the treatment - - covariates: Optional extra variables affecting the outcome only - -# Examples: -```julia -ATE₁ = ATE( - outcome=:Y₁, - treatment=(T₁=(case=1, control=0),), - confounders=[:W₁, :W₂], - covariates=[:C₁] -) - -ATE₂ = ATE( - outcome=:Y₂, - treatment=(T₁=(case=1, control=0), T₂=(case="A", control="B")), - confounders=[:W₁], -) -``` -""" struct AverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol @@ -138,32 +53,6 @@ name(::Type{ATE}) = "ATE" ### Interaction Average Treatment Effect ### ##################################################################### -""" -# IATE: Interaction Average Treatment Effect - -Mathematical definition for pairwise interaction: - - Eₓ[E[Y|do(T₁=1, T₂=1), X]] - Eₓ[E[Y|do(T₁=1, T₂=0), X]] - Eₓ[E[Y|do(T₁=0, T₂=1), X]] + Eₓ[E[Y|do(T₁=0, T₂=0), X]] - -# Assumed Causal graph: - -$causal_graph - -# Fields: - - outcome: A symbol identifying the outcome variable of interest - - treatment: A NamedTuple linking each treatment variable to case/control values - - confounders: Confounding variables affecting both the outcome and the treatment - - covariates: Optional extra variables affecting the outcome only - -# Examples: -```julia -IATE₁ = IATE( - outcome=:Y₁, - treatment=(T₁=(case=1, control=0), T₂=(case="A", control="B")), - confounders=[:W₁], -) -``` -""" struct InteractionAverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol @@ -209,70 +98,16 @@ function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEsti println(io, param_string) end -equations_to_fit(Ψ::CMCompositeEstimand) = (outcome_equation(Ψ), (Ψ.scm[t] for t in treatments(Ψ))...) - outcome_equation(Ψ::CMCompositeEstimand) = Ψ.scm[outcome(Ψ)] -variable_not_indataset(role, variable) = string(role, " variable: ", variable, " is not in the dataset.") - -function isidentified(Ψ::CMCompositeEstimand, dataset) - reasons = String[] - sch = Tables.schema(dataset) - # Check outcome variable - outcome(Ψ) ∈ sch.names || push!( - reasons, - variable_not_indataset("Outcome", outcome(Ψ)) - ) - - # Check Treatment variable - for treatment in treatments(Ψ) - treatment ∈ sch.names || push!( - reasons, - variable_not_indataset("Treatment", treatment) - ) - end - - # Check adjustment set - for confounder in Set(vcat(confounders(Ψ)...)) - confounder ∈ sch.names || push!( - reasons, - variable_not_indataset("Confounding", confounder) - ) - end - - return length(reasons) == 0, reasons -end - -confounders(scm, variable, outcome) = - intersect(parents(scm, outcome), parents(scm, variable)) - -function confounders(Ψ::CMCompositeEstimand) - confounders_ = [] - treatments_ = Tuple(treatments(Ψ)) - for treatment in treatments_ - push!(confounders_, confounders(Ψ.scm, treatment, outcome(Ψ))) - end - return NamedTuple{treatments_}(confounders_) -end - -function confounders(dataset, Ψ::CMCompositeEstimand) - confounders_ = [] - treatments_ = Tuple(treatments(Ψ)) - for treatment in treatments_ - push!( - confounders_, - selectcols(dataset, confounders(Ψ.scm, treatment, outcome(Ψ))) - ) - end - return NamedTuple{treatments_}(confounders_) -end - treatments(Ψ::CMCompositeEstimand) = collect(keys(Ψ.treatment)) treatments(dataset, Ψ::CMCompositeEstimand) = selectcols(dataset, treatments(Ψ)) outcome(Ψ::CMCompositeEstimand) = Ψ.outcome outcome(dataset, Ψ::CMCompositeEstimand) = Tables.getcolumn(dataset, outcome(Ψ)) +confounders(dataset, Ψ) = (;(T => selectcols(dataset, keys(Ψ.scm[T].mach.data[1])) for T in treatments(Ψ))...) + F_model(::Type{<:AbstractVector{<:MLJBase.Continuous}}) = LinearRegressor(fit_intercept=false, offsetcol = :offset) @@ -283,7 +118,6 @@ F_model(t::Type{Any}) = throw(ArgumentError("Cannot proceed with Q model with ta function param_key(Ψ::CMCompositeEstimand) return ( - join(values(confounders(Ψ))..., "_"), join(treatments(Ψ), "_"), string(outcome(Ψ)), ) @@ -292,15 +126,6 @@ end """ optimize_ordering!(estimands::Vector{<:Estimand}) -Reorders the given estimands so that most nuisance estimands fits can be -reused. Given the assumed causal graph: - -$causal_graph - -and the requirements to estimate both p(T|W) and E[Y|W, T, C]. - -A natural ordering of the estimands in order to save computations is given by the -following variables ordering: (W, T, Y, C) """ optimize_ordering!(estimands::Vector{<:Estimand}) = sort!(estimands, by=param_key) diff --git a/src/estimation.jl b/src/estimation.jl index 12fbccc2..19229a10 100644 --- a/src/estimation.jl +++ b/src/estimation.jl @@ -1,11 +1,30 @@ +function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false) + models_input_variables = get_models_input_variables(adjustment_method, Ψ) + for (variable, input_variables) in zip(keys(models_input_variables), models_input_variables) + fit!(Ψ.scm[variable], dataset; + input_variables=input_variables, + verbosity=verbosity, + force=force + ) + end +end - -function tmle(Ψ::Estimand, dataset; verbosity=1, cache=true, force=false, threshold=1e-8, weighted_fluctuation=false) +function tmle(Ψ::CMCompositeEstimand, dataset; + adjustment_method=BackdoorAdjustment(), + verbosity=1, + force=false, + threshold=1e-8, + weighted_fluctuation=false + ) # Check the estimand against the dataset - check_estimand(Ψ, dataset) + check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's equations verbosity >= 1 && @info "Fitting the required equations..." - fit!(Ψ, dataset; verbosity=verbosity, cache=cache, force=force) + fit!(Ψ, dataset; + adjustment_method=adjustment_method, + verbosity=verbosity, + force=force + ) # TMLE step verbosity >= 1 && @info "Performing TMLE..." fluctuation_mach = tmle_step(Ψ, verbosity=verbosity, threshold=threshold, weighted_fluctuation=weighted_fluctuation) @@ -25,11 +44,12 @@ function tmle_step(Ψ::CMCompositeEstimand; verbosity=1, threshold=1e-8, weighte # Compute offset offset = TMLE.compute_offset(MLJBase.predict(outcome_mach)) # Compute clever covariate and weights for weighted fluctuation mode - W = confounders(X, Ψ) T = treatments(X, Ψ) + W = confounders(X, Ψ) covariate, weights = TMLE.clever_covariate_and_weights( - Ψ.scm, T, W, indicator_fns(Ψ); - threshold=threshold, weighted_fluctuation=weighted_fluctuation + Ψ.scm, W, T, indicator_fns(Ψ); + threshold=threshold, + weighted_fluctuation=weighted_fluctuation ) # Fit fluctuation Xfluct = TMLE.fluctuation_input(covariate, offset) @@ -50,14 +70,15 @@ function counterfactual_aggregates(Ψ::CMCompositeEstimand, fluctuation_mach; th for (vals, sign) in indicators # Counterfactual dataset for a given treatment setting T_ct = TMLE.counterfactualTreatment(vals, Ttemplate) - X_ct = selectcols(merge(X, T_ct), parents(outcome_eq)) + X_ct = selectcols(merge(X, T_ct), keys(X)) + W = confounders(X_ct, Ψ) # Counterfactual predictions with the initial Q - ŷᵢ = MLJBase.predict(outcome_eq.mach, X_ct) + ŷᵢ = MLJBase.predict(outcome_eq.mach, X_ct) counterfactual_aggregateᵢ .+= sign .* expected_value(ŷᵢ) # Counterfactual predictions with F offset = compute_offset(ŷᵢ) covariate, _ = clever_covariate_and_weights( - Ψ.scm, T_ct, confounders(X, Ψ), indicators; + Ψ.scm, W, T_ct, indicators; threshold=threshold, weighted_fluctuation=weighted_fluctuation ) Xfluct_ct = fluctuation_input(covariate, offset) diff --git a/src/scm.jl b/src/scm.jl index 74b078c6..4f256f87 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -25,19 +25,47 @@ NoModelError(eq::SE) = ArgumentError(string("It seems the following structural e fit_message(eq::SE) = string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".") -function MLJBase.fit!(eq::SE, dataset; verbosity=1, cache=true, force=false) + +equation_inputs(eq::SE, input_variables::Nothing) = parents(eq) +equation_inputs(eq::SE, input_variables) = input_variables + +""" + MLJBase.fit!(eq::SE, dataset; input_variables=nothing, verbosity=1, cache=true, force=false) + +Fits the outcome's Structural Equation using the dataset with inputs variables given by either: + +- The variables corresponding to parents(eq) if `input_variables`= nothing +- The alternative variables provided by `input_variables`. + +Extra keyword arguments are: + +- cache: Controls whether the associated MLJ.Machine will cache data. +- force: Controls whether to force the associated MLJ.Machine to refit even if neither the model or data has changed. +- verbosity: Controls the verbosity level +""" +function MLJBase.fit!(eq::SE, dataset; input_variables=nothing, verbosity=1, cache=true, force=false) eq.model !== nothing || throw(NoModelError(eq)) - # Fit if never fitted or if new model - if eq.mach === nothing || eq.model != eq.mach.model + # Fit when: never fitted OR model has changed OR inputs have changed + dofit = eq.mach === nothing || eq.model != eq.mach.model + if !dofit + if isdefined(eq.mach, :data) + dofit = Tables.columnnames(eq.mach.data[1]) !== Tuple(equation_inputs(eq, input_variables)) + else + dofit = true + end + end + + if dofit verbosity >= 1 && @info(fit_message(eq)) - data = nomissing(dataset, vcat(parents(eq), outcome(eq))) - X = selectcols(data, parents(eq)) + input_colnames = equation_inputs(eq, input_variables) + data = nomissing(dataset, vcat(input_colnames, outcome(eq))) + X = selectcols(data, input_colnames) y = Tables.getcolumn(data, outcome(eq)) mach = machine(eq.model, X, y, cache=cache) MLJBase.fit!(mach, verbosity=verbosity-1) eq.mach = mach - # Otherwise only fit if force is true - else + # Also refit if force is true + else verbosity >= 1 && force === true && @info(fit_message(eq)) MLJBase.fit!(eq.mach, verbosity=verbosity-1, force=force) end diff --git a/src/utils.jl b/src/utils.jl index 6a19d0c6..59293c0e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -111,9 +111,12 @@ function check_treatment_settings(setting, levels, treatment_name) AbsentLevelError(treatment_name, setting, levels)) end -function check_estimand(Ψ::Estimand, dataset) - identified, reasons = isidentified(Ψ, dataset) - identified || throw(NotIdentifiedError(reasons)) +""" + check_treatment_levels(Ψ::CMCompositeEstimand, dataset) + +Makes sure the defined treatment levels are present in the dataset. +""" +function check_treatment_levels(Ψ::CMCompositeEstimand, dataset) for treatment_name in treatments(Ψ) treatment_levels = levels(Tables.getcolumn(dataset, treatment_name)) treatment_settings = getproperty(Ψ.treatment, treatment_name) @@ -168,7 +171,7 @@ if `weighted_fluctuation = true`: where SpecialIndicator(t) is defined in `indicator_fns`. """ -function clever_covariate_and_weights(scm::SCM, T, W, indicator_fns; threshold=1e-8, weighted_fluctuation=false) +function clever_covariate_and_weights(scm::SCM, W, T, indicator_fns; threshold=1e-8, weighted_fluctuation=false) # Compute the indicator values indic_vals = TMLE.indicator_values(indicator_fns, T) weights = balancing_weights(scm, W, T, threshold=threshold) diff --git a/test/adjustment.jl b/test/adjustment.jl new file mode 100644 index 00000000..afe51e7a --- /dev/null +++ b/test/adjustment.jl @@ -0,0 +1,28 @@ +module TestAdjustment + +using Test +using TMLE + +@testset "Test BackdoorAdjustment" begin + scm = SCM( + SE(:Y, [:I, :V, :C]), + SE(:V, [:T, :K]), + SE(:T, [:W, :X]), + SE(:I, [:W, :G]), + ) + Ψ = CM(scm, treatment=(T=1,), outcome=:Y) + adjustment_method = BackdoorAdjustment() + @test TMLE.get_models_input_variables(adjustment_method, Ψ) == ( + T = [:W, :X], + Y = [:T, :W, :X] + ) + adjustment_method = BackdoorAdjustment(outcome_extra=[:C]) + @test TMLE.get_models_input_variables(adjustment_method, Ψ) == ( + T = [:W, :X], + Y = [:T, :W, :X, :C] + ) +end + +end + +true \ No newline at end of file diff --git a/test/estimands.jl b/test/estimands.jl index d6b27b20..99685a6e 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -47,13 +47,7 @@ using MLJGLMInterface @test TMLE.treatments(Ψ) == [:T] @test TMLE.treatments(dataset, Ψ) == (T=dataset.T,) @test TMLE.outcome(Ψ) == :Y - @test TMLE.confounders(Ψ) == (T=[:W₁, :W₂],) - @test TMLE.confounders(dataset, Ψ) == (T=(W₁=dataset.W₁, W₂=dataset.W₂),) - identified, reasons = isidentified(Ψ, dataset) - @test identified === true - @test reasons == [] - ηs = TMLE.equations_to_fit(Ψ) - @test ηs === (scm.Y, scm.T) + # ATE scm = SCM( SE(:Z, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁]), @@ -67,20 +61,7 @@ using MLJGLMInterface ) @test TMLE.treatments(Ψ) == [:T₁, :T₂] @test TMLE.outcome(Ψ) == :Z - @test TMLE.confounders(Ψ) == (T₁=[:W₁₁, :W₁₂], T₂=[:W₂₁]) - identified, reasons = isidentified(Ψ, dataset) - @test identified === false - expected_reasons = [ - "Outcome variable: Z is not in the dataset.", - "Treatment variable: T₁ is not in the dataset.", - "Treatment variable: T₂ is not in the dataset.", - "Confounding variable: W₁₂ is not in the dataset.", - "Confounding variable: W₂₁ is not in the dataset.", - "Confounding variable: W₁₁ is not in the dataset." - ] - @test reasons == expected_reasons - ηs = TMLE.equations_to_fit(Ψ) - @test ηs === (scm.Z, scm.T₁, scm.T₂) + # IATE Ψ = IATE( scm=scm, @@ -89,12 +70,6 @@ using MLJGLMInterface ) @test TMLE.treatments(Ψ) == [:T₁, :T₂] @test TMLE.outcome(Ψ) == :Z - @test TMLE.confounders(Ψ) == (T₁=[:W₁₁, :W₁₂], T₂=[:W₂₁]) - identified, reasons = isidentified(Ψ, dataset) - @test identified === false - @test reasons == expected_reasons - ηs = TMLE.equations_to_fit(Ψ) - @test ηs === (scm.Z, scm.T₁, scm.T₂) end @testset "Test fit!" begin @@ -125,9 +100,11 @@ end treatment=(T₁=1,), outcome=:Ycat ) - fit!(Ψ, dataset, verbosity=0) + fit!(Ψ, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=0) @test scm.Ycat.mach isa Machine + @test keys(scm.Ycat.mach.data[1]) == (:T₁, :W₁₁, :W₁₂, :C) @test scm.T₁.mach isa Machine + @test keys(scm.T₁.mach.data[1]) == (:W₁₁, :W₁₂) @test scm.Ycont.mach isa Nothing @test scm.T₂.mach isa Nothing @@ -139,9 +116,11 @@ end ) log_sequence = ( (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Fitting Structural Equation corresponding to variable Ycat.") ) @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) @test scm.Ycat.mach isa Machine + keys(scm.Ycat.mach.data[1]) @test scm.T₁.mach isa Machine @test scm.T₂.mach isa Machine @test scm.Ycont.mach isa Nothing @@ -225,8 +204,8 @@ end ), ] # Test param_key - @test TMLE.param_key(estimands[1]) == ("W₁_W₂", "T₁_T₂", "Y₁") - @test TMLE.param_key(estimands[end]) == ("W₁_W₂_W₃", "T₂", "Y₂") + @test TMLE.param_key(estimands[1]) == ("T₁_T₂", "Y₁") + @test TMLE.param_key(estimands[end]) == ("T₂", "Y₂") # Non mutating function estimands = shuffle(rng, estimands) ordered_estimands = optimize_ordering(estimands) @@ -249,6 +228,7 @@ end @test estimands == ordered_estimands end + @testset "Test structs are concrete types" begin @test isconcretetype(ATE) @test isconcretetype(IATE) diff --git a/test/estimation.jl b/test/estimation.jl index d933c3e2..dab5134e 100644 --- a/test/estimation.jl +++ b/test/estimation.jl @@ -69,8 +69,8 @@ table_types = (Tables.columntable, DataFrame) # Run TMLE log_sequence = ( (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable Y₁."), (:info, "Fitting Structural Equation corresponding to variable T₁."), + (:info, "Fitting Structural Equation corresponding to variable Y₁."), (:info, "Performing TMLE..."), (:info, "Done.") ) @@ -102,7 +102,7 @@ table_types = (Tables.columntable, DataFrame) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # New marginal treatment estimation: T₂ - # This will trigger fit of the corresponding equation + # This will trigger fit of the corresponding equations Ψ = ATE( scm=scm, outcome=:Y₁, @@ -111,6 +111,7 @@ table_types = (Tables.columntable, DataFrame) log_sequence = ( (:info, "Fitting the required equations..."), (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Fitting Structural Equation corresponding to variable Y₁."), (:info, "Performing TMLE..."), (:info, "Done.") ) @@ -141,7 +142,6 @@ table_types = (Tables.columntable, DataFrame) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - # At this point all equations have been fitted # New estimand with 2 treatment variables Ψ = ATE( scm=scm, @@ -150,6 +150,7 @@ table_types = (Tables.columntable, DataFrame) ) log_sequence = ( (:info, "Fitting the required equations..."), + (:info, "Fitting Structural Equation corresponding to variable Y₁."), (:info, "Performing TMLE..."), (:info, "Done.") ) @@ -166,6 +167,11 @@ table_types = (Tables.columntable, DataFrame) outcome=:Y₁, treatment=(T₁=(case=true, control=false), T₂=(case=false, control=true)), ) + log_sequence = ( + (:info, "Fitting the required equations..."), + (:info, "Performing TMLE..."), + (:info, "Done.") + ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); Ψ₀ = -1.5 @test tmle_result.estimand == Ψ @@ -185,9 +191,9 @@ end ) log_sequence = ( (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable Y₁."), (:info, "Fitting Structural Equation corresponding to variable T₁."), (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Fitting Structural Equation corresponding to variable Y₁."), (:info, "Performing TMLE..."), (:info, "Done.") ) @@ -260,9 +266,9 @@ end ) log_sequence = ( (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable Y₃."), (:info, "Fitting Structural Equation corresponding to variable T₁."), (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Fitting Structural Equation corresponding to variable Y₃."), (:info, "Performing TMLE..."), (:info, "Done.") ) @@ -279,8 +285,8 @@ end log_sequence = ( (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable Y₃."), (:info, "Fitting Structural Equation corresponding to variable T₂."), + (:info, "Fitting Structural Equation corresponding to variable Y₃."), (:info, "Performing TMLE..."), (:info, "Done.") ) diff --git a/test/utils.jl b/test/utils.jl index 33d91e81..4efc3cb2 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -18,22 +18,9 @@ using MLJModels T = categorical([1, 2, 3]), Y = [1., 2., 3.] ) - scm = SCM( - SE(:T, [:W₁, :W₂]), - SE(:Y, [:T, :W₁, :W₂]) - ) - # Not identified estimand - Ψ = ATE( - scm=scm, - outcome=:Y, - treatment=(T=(case=1, control=0),), - ) - reasons = ["Confounding variable: W₂ is not in the dataset."] - @test_throws TMLE.NotIdentifiedError(reasons) TMLE.check_estimand(Ψ, dataset) - # Treatment values absent from dataset scm = SCM( - SE(:T, [:W₁, :W₂]), + SE(:T, [:W₁]), SE(:Y, [:T, :W₁]) ) Ψ = ATE( @@ -41,14 +28,14 @@ using MLJModels outcome=:Y, treatment=(T=(case=1, control=0),) ) - @test_throws TMLE.AbsentLevelError("T", "control", 0, [1, 2, 3]) TMLE.check_estimand(Ψ, dataset) + @test_throws TMLE.AbsentLevelError("T", "control", 0, [1, 2, 3]) TMLE.check_treatment_levels(Ψ, dataset) Ψ = CM( scm=scm, outcome=:Y, treatment=(T=0,), ) - @test_throws TMLE.AbsentLevelError("T", 0, [1, 2, 3]) TMLE.check_estimand(Ψ, dataset) + @test_throws TMLE.AbsentLevelError("T", 0, [1, 2, 3]) TMLE.check_treatment_levels(Ψ, dataset) end @testset "Test indicator_fns & indicator_values" begin @@ -187,17 +174,17 @@ end W₂ = rand(7), W₃ = rand(7) ) - fit!(scm, dataset, verbosity=0) + fit!(Ψ, dataset, verbosity=0) indicator_fns = TMLE.indicator_fns(Ψ) T = TMLE.treatments(dataset, Ψ) @test T == (T=dataset.T,) W = TMLE.confounders(dataset, Ψ) - @test W == (T=(W₁=dataset.W₁, W₂=dataset.W₂, W₃=dataset.W₃),) - + @test W == (T = (W₁=dataset.W₁, W₂=dataset.W₂, W₃=dataset.W₃),) + @test TMLE.compute_offset(Ψ) == repeat([mean(dataset.Y)], 7) weighted_fluctuation = true bw = TMLE.balancing_weights(scm, W, T) - cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns, weighted_fluctuation=weighted_fluctuation) + cov, w = TMLE.clever_covariate_and_weights(scm, W, T, indicator_fns, weighted_fluctuation=weighted_fluctuation) @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] @test w == bw == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] @@ -209,7 +196,7 @@ end intercept = 0.0 ) - cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns) + cov, w = TMLE.clever_covariate_and_weights(scm, W, T, indicator_fns) @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] unweighted_mach = TMLE.tmle_step(Ψ, weighted_fluctuation=false) @@ -239,7 +226,7 @@ end W₂ = rand(7), W₃ = rand(7) ) - fit!(scm, dataset, verbosity=0) + fit!(Ψ, dataset, verbosity=0) indicator_fns = TMLE.indicator_fns(Ψ) T = TMLE.treatments(dataset, Ψ) @test T == (T₁ = dataset.T₁, T₂ = dataset.T₂) @@ -251,7 +238,7 @@ end # Because the outcome is binary, the offset is the logit @test TMLE.compute_offset(Ψ) == repeat([0.28768207245178085], 7) - cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns) + cov, w = TMLE.clever_covariate_and_weights(scm, W, T, indicator_fns) @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] @@ -285,7 +272,7 @@ end W = rand(7), ) - fit!(scm, dataset, verbosity=0) + fit!(Ψ, dataset, verbosity=0) indicator_fns = TMLE.indicator_fns(Ψ) T = TMLE.treatments(dataset, Ψ) @test T == (T₁ = dataset.T₁, T₂ = dataset.T₂, T₃ = dataset.T₃) @@ -299,7 +286,7 @@ end @test TMLE.compute_offset(Ψ) == repeat([4.0], 7) - cov, w = TMLE.clever_covariate_and_weights(scm, T, W, indicator_fns) + cov, w = TMLE.clever_covariate_and_weights(scm, W, T, indicator_fns) @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] From 3b6e05e60d2312254942b6f57fc8a444b8d89a7f Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 7 Aug 2023 19:26:33 +0100 Subject: [PATCH 027/151] add more docstrings --- src/adjustment.jl | 9 ++++++ src/estimands.jl | 80 ++++++++++++++++++++++++++++++++++++++++++++--- src/estimation.jl | 21 +++++++++++++ src/scm.jl | 63 +++++++++++++++++++++++++++++++++++++ 4 files changed, 168 insertions(+), 5 deletions(-) diff --git a/src/adjustment.jl b/src/adjustment.jl index 71572c4f..6f7ba79c 100644 --- a/src/adjustment.jl +++ b/src/adjustment.jl @@ -5,6 +5,15 @@ struct BackdoorAdjustment <: AdjustmentMethod outcome_extra::Vector{Symbol} end +""" + BackdoorAdjustment(;outcome_extra=[]) + +The adjustment set for each treatment variable is simply the set of direct parents in the +associated structural model. + +`outcome_extra` are optional additional variables that can be used to fit the outcome model +in order to improve inference. +""" BackdoorAdjustment(;outcome_extra=[]) = BackdoorAdjustment(outcome_extra) """ diff --git a/src/estimands.jl b/src/estimands.jl index 76a7d05c..6beaa668 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -6,11 +6,32 @@ A Estimand is a functional on distribution space Ψ: ℳ → ℜ. """ abstract type Estimand end - ##################################################################### ### Conditional Mean ### ##################################################################### +""" +# Conditional Mean / CM + +## Definition + +``CM(Y, T=t) = E[Y|do(T=t)]`` + +## Constructors + +- CM(;scm::SCM, outcome, treatment) +- CM(scm::SCM; outcome, treatment) + +where: + +- scm: is a `StructuralCausalModel` (see [`SCM`](@ref)) +- outcome: is a `Symbol` +- treatment: is a `NamedTuple` + +## Example + +Ψ = CM(scm, outcome=:Y, treatment=(T=1,)) +""" struct ConditionalMean <: Estimand scm::StructuralCausalModel outcome::Symbol @@ -23,15 +44,36 @@ end const CM = ConditionalMean -CM(;scm, outcome, treatment) = CM(scm, outcome, treatment) -CM(scm; outcome, treatment) = CM(scm, outcome, treatment) +CM(;scm::SCM, outcome, treatment) = CM(scm, outcome, treatment) +CM(scm::SCM; outcome, treatment) = CM(scm, outcome, treatment) name(::Type{CM}) = "CM" ##################################################################### ### Average Treatment Effect ### ##################################################################### +""" +# Average Treatment Effect / ATE + +## Definition + +``ATE(Y, T, case, control) = E[Y|do(T=case)] - E[Y|do(T=control)`` + +## Constructors + +- ATE(;scm::SCM, outcome, treatment) +- ATE(scm::SCM; outcome, treatment) +where: + +- scm: is a `StructuralCausalModel` (see [`SCM`](@ref)) +- outcome: is a `Symbol` +- treatment: is a `NamedTuple` + +## Example + +Ψ = ATE(scm, outcome=:Y, treatment=(T=(case=1,control=0),) +""" struct AverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol @@ -44,8 +86,8 @@ end const ATE = AverageTreatmentEffect -ATE(;scm, outcome, treatment) = ATE(scm, outcome, treatment) -ATE(scm; outcome, treatment) = ATE(scm, outcome, treatment) +ATE(;scm::SCM, outcome, treatment) = ATE(scm, outcome, treatment) +ATE(scm::SCM; outcome, treatment) = ATE(scm, outcome, treatment) name(::Type{ATE}) = "ATE" @@ -53,6 +95,30 @@ name(::Type{ATE}) = "ATE" ### Interaction Average Treatment Effect ### ##################################################################### +""" +# Interaction Average Treatment Effect / IATE + +## Definition + +For two treatments with settings (1, 0): + +``IATE = E[Y|do(T₁=1, T₂=1)] - E[Y|do(T₁=1, T₂=0)] - E[Y|do(T₁=0, T₂=1)] + E[Y|do(T₁=0, T₂=0)]`` + +## Constructors + +- IATE(;scm::SCM, outcome, treatment) +- IATE(scm::SCM; outcome, treatment) + +where: + +- scm: is a `StructuralCausalModel` (see [`SCM`](@ref)) +- outcome: is a `Symbol` +- treatment: is a `NamedTuple` + +## Example + +Ψ = IATE(scm, outcome=:Y, treatment=(T₁=(case=1,control=0), T₂=(case=1,control=0)) +""" struct InteractionAverageTreatmentEffect <: Estimand scm::StructuralCausalModel outcome::Symbol @@ -74,6 +140,8 @@ name(::Type{IATE}) = "IATE" ### Methods ### ##################################################################### +AVAILABLE_ESTIMANDS = (CM, ATE, IATE) + CMCompositeEstimand = Union{CM, ATE, IATE} VariableNotAChildInSCMError(variable) = ArgumentError(string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")) @@ -126,6 +194,8 @@ end """ optimize_ordering!(estimands::Vector{<:Estimand}) +Optimizes the order of the `estimands` to maximize reuse of +fitted equations in the associated SCM. """ optimize_ordering!(estimands::Vector{<:Estimand}) = sort!(estimands, by=param_key) diff --git a/src/estimation.jl b/src/estimation.jl index 19229a10..60f21bab 100644 --- a/src/estimation.jl +++ b/src/estimation.jl @@ -9,6 +9,27 @@ function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=Backdo end end +""" + tmle(Ψ::CMCompositeEstimand, dataset; + adjustment_method=BackdoorAdjustment(), + verbosity=1, + force=false, + threshold=1e-8, + weighted_fluctuation=false + ) + +Performs Targeted Minimum Loss Based Estimation of the target estimand. + +## Arguments + +- Ψ: An estimand of interest. +- dataset: A table respecting the `Tables.jl` interface. +- adjustment_method: A confounding adjustment method. +- verbosity: Level of logging. +- force: To force refit of machines in the SCM . +- threshold: The balancing score will be bounded to respect this threshold. +- weighted_fluctuation: To use a weighted fluctuation instead of the vanilla TMLE, can improve stability. +""" function tmle(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, diff --git a/src/scm.jl b/src/scm.jl index 4f256f87..e5ff5a4a 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -1,6 +1,20 @@ ##################################################################### ### Structural Equation ### ##################################################################### +""" +# Structural Equation / SE + +## Constructors + +- SE(outcome, parents; model=nothing) +- SE(;outcome, parents, model=nothing) + +## Examples + +eq = SE(:Y, [:T, :W]) +eq = SE(:Y, [:T, :W], model = LinearRegressor()) + +""" mutable struct StructuralEquation outcome::Symbol parents::Vector{Symbol} @@ -95,6 +109,22 @@ parents(se::SE) = se.parents AlreadyAssignedError(key) = ArgumentError(string("Variable ", key, " is already assigned in the SCM.")) +""" +# Structural Causal Model / SCM + +## Constructors + +SCM(;equations=Dict{Symbol, SE}()) +SCM(equations::Vararg{SE}) + +## Examples + +scm = SCM( + SE(:Y, [:T, :W, :C]), + SE(:T, [:W]) +) + +""" struct StructuralCausalModel equations::Dict{Symbol, StructuralEquation} end @@ -191,6 +221,20 @@ end vcat_covariates(treatment, confounders, covariates::Nothing) = vcat(treatment, confounders) vcat_covariates(treatment, confounders, covariates) = vcat(treatment, confounders, covariates) +""" + StaticConfoundedModel( + outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}}; + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = TreatmentTransformer() |> LinearRegressor(), + treatment_model = LinearBinaryClassifier() + ) + +Defines a classic Structural Causal Model with one outcome, one treatment, +a set of confounding variables and optional covariates influencing the outcome only. + +The `outcome_model` and `treatment_model` define the relationship between +the outcome (resp. treatment) and their ancestors. +""" function StaticConfoundedModel( outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}}; covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, @@ -210,6 +254,25 @@ function StaticConfoundedModel( return StructuralCausalModel(Yeq, Teq) end +""" + StaticConfoundedModel( + outcomes::Vector{Symbol}, + treatments::Vector{Symbol}, + confounders::Union{Symbol, AbstractVector{Symbol}}; + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = TreatmentTransformer() |> LinearRegressor(), + treatment_model = LinearBinaryClassifier() + ) + +Defines a classic Structural Causal Model with multiple outcomes, multiple treatments, +a set of confounding variables and optional covariates influencing the outcomes only. + +All treatments are assumed to be direct parents of all outcomes. The confounding variables +are shared for all treatments. + +The `outcome_model` and `treatment_model` define the relationships between +the outcomes (resp. treatments) and their ancestors. +""" function StaticConfoundedModel( outcomes::Vector{Symbol}, treatments::Vector{Symbol}, From 1deee66220284399c4186bf1c1ff20789eb8da2f Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 8 Aug 2023 09:06:41 +0100 Subject: [PATCH 028/151] rename some structs --- src/TMLE.jl | 4 +-- src/estimate.jl | 61 ++++++++++++++++++++++++--------------------- src/estimation.jl | 4 +-- test/composition.jl | 4 +-- test/estimands.jl | 3 ++- 5 files changed, 41 insertions(+), 35 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index 7a089d05..b6a1e1c6 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -25,8 +25,8 @@ import AbstractDifferentiation as AD export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel -export CM, ATE, IATE -export parents, fit!, reset!, isidentified, equations, optimize_ordering, optimize_ordering! +export CM, ATE, IATE, AVAILABLE_ESTIMANDS +export parents, fit!, reset!, equations, optimize_ordering, optimize_ordering! export tmle export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose diff --git a/src/estimate.jl b/src/estimate.jl index 910ddfb0..be0ce081 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -1,20 +1,25 @@ -abstract type AbstractTMLE end +abstract type AsymptoticallyLinearEstimate end -struct ALEstimate{T<:AbstractFloat} <: AbstractTMLE +struct TMLEstimate{T<:AbstractFloat} <: AsymptoticallyLinearEstimate Ψ̂::T IC::Vector{T} end -struct ComposedTMLE{T<:AbstractFloat} <: AbstractTMLE +struct OSEstimate{T<:AbstractFloat} <: AsymptoticallyLinearEstimate + Ψ̂::T + IC::Vector{T} +end + +struct ComposedTMLEstimate{T<:AbstractFloat} <: AsymptoticallyLinearEstimate Ψ̂::Array{T} σ̂::Matrix{T} end struct TMLEResult{P <: Estimand, T<:AbstractFloat} estimand::P - tmle::ALEstimate{T} - onestep::ALEstimate{T} + tmle::TMLEstimate{T} + onestep::OSEstimate{T} initial::T end @@ -23,7 +28,7 @@ function Base.show(io::IO, r::TMLEResult) onesteptest = OneSampleTTest(r.onestep) data = [ :TMLE estimate(r.tmle) confint(tmletest) pvalue(tmletest); - :OneStep estimate(r.onestep) confint(onesteptest) pvalue(onesteptest); + :OSE estimate(r.onestep) confint(onesteptest) pvalue(onesteptest); :Naive r.initial nothing nothing ] pretty_table(io, data;header=["Estimator", "Estimate", "95% Confidence Interval", "P-value"]) @@ -32,60 +37,60 @@ end """ - estimate(r::ALEstimate) + estimate(r::AsymptoticallyLinearEstimate) Retrieves the final estimate: after the TMLE step. """ -estimate(r::ALEstimate) = r.Ψ̂ +estimate(r::AsymptoticallyLinearEstimate) = r.Ψ̂ """ - estimate(r::ComposedTMLE) + estimate(r::ComposedTMLEstimate) Retrieves the final estimate: after the TMLE step. """ -estimate(r::ComposedTMLE) = length(r.Ψ̂) == 1 ? r.Ψ̂[1] : r.Ψ̂ +estimate(r::ComposedTMLEstimate) = length(r.Ψ̂) == 1 ? r.Ψ̂[1] : r.Ψ̂ """ - var(r::ALEstimate) + var(r::AsymptoticallyLinearEstimate) Computes the estimated variance associated with the estimate. """ -Statistics.var(r::ALEstimate) = var(r.IC)/size(r.IC, 1) +Statistics.var(r::AsymptoticallyLinearEstimate) = var(r.IC)/size(r.IC, 1) """ - var(r::ComposedTMLE) + var(r::ComposedTMLEstimate) Computes the estimated variance associated with the estimate. """ -Statistics.var(r::ComposedTMLE) = length(r.σ̂) == 1 ? r.σ̂[1] : r.σ̂ +Statistics.var(r::ComposedTMLEstimate) = length(r.σ̂) == 1 ? r.σ̂[1] : r.σ̂ """ - OneSampleZTest(r::ALEstimate, Ψ₀=0) + OneSampleZTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) -Performs a Z test on the ALEstimate. +Performs a Z test on the AsymptoticallyLinearEstimate. """ -HypothesisTests.OneSampleZTest(r::ALEstimate, Ψ₀=0) = OneSampleZTest(estimate(r), std(r.IC), size(r.IC, 1), Ψ₀) +HypothesisTests.OneSampleZTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) = OneSampleZTest(estimate(r), std(r.IC), size(r.IC, 1), Ψ₀) """ - OneSampleTTest(r::ALEstimate, Ψ₀=0) + OneSampleTTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) -Performs a T test on the ALEstimate. +Performs a T test on the AsymptoticallyLinearEstimate. """ -HypothesisTests.OneSampleTTest(r::ALEstimate, Ψ₀=0) = OneSampleTTest(estimate(r), std(r.IC), size(r.IC, 1), Ψ₀) +HypothesisTests.OneSampleTTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) = OneSampleTTest(estimate(r), std(r.IC), size(r.IC, 1), Ψ₀) """ - OneSampleTTest(r::ComposedTMLE, Ψ₀=0) + OneSampleTTest(r::ComposedTMLEstimate, Ψ₀=0) -Performs a T test on the ComposedTMLE. +Performs a T test on the ComposedTMLEstimate. """ -function HypothesisTests.OneSampleTTest(r::ComposedTMLE, Ψ₀=0) +function HypothesisTests.OneSampleTTest(r::ComposedTMLEstimate, Ψ₀=0) @assert length(r.Ψ̂) > 1 "OneSampleTTest is only implemeted for real-valued statistics." return OneSampleTTest(estimate(r), sqrt(var(r)), 1, Ψ₀) end """ - compose(f, estimation_results::Vararg{ALEstimate, N}) where N + compose(f, estimation_results::Vararg{AsymptoticallyLinearEstimate, N}) where N Provides an estimator of f(estimation_results...). @@ -118,7 +123,7 @@ Hence, the only thing we need to do is: # Arguments - f: An array-input differentiable map. -- estimation_results: 1 or more `ALEstimate` structs. +- estimation_results: 1 or more `AsymptoticallyLinearEstimate` structs. # Examples @@ -129,17 +134,17 @@ f(x, y) = [x^2 - y, y - 3x] compose(f, res₁, res₂) ``` """ -function compose(f, estimators::Vararg{ALEstimate, N}; backend=AD.ZygoteBackend()) where N +function compose(f, estimators::Vararg{AsymptoticallyLinearEstimate, N}; backend=AD.ZygoteBackend()) where N Σ = cov(estimators...) estimates = [estimate(r) for r in estimators] f₀, Js = AD.value_and_jacobian(backend, f, estimates...) J = hcat(Js...) n = size(first(estimators).IC, 1) σ₀ = (J*Σ*J')/n - return ComposedTMLE(collect(f₀), σ₀) + return ComposedTMLEstimate(collect(f₀), σ₀) end -function Statistics.cov(estimators::Vararg{ALEstimate, N}) where N +function Statistics.cov(estimators::Vararg{AsymptoticallyLinearEstimate, N}) where N X = hcat([r.IC for r in estimators]...) return Statistics.cov(X, dims=1, corrected=true) end diff --git a/src/estimation.jl b/src/estimation.jl index 60f21bab..b7826ca9 100644 --- a/src/estimation.jl +++ b/src/estimation.jl @@ -51,8 +51,8 @@ function tmle(Ψ::CMCompositeEstimand, dataset; fluctuation_mach = tmle_step(Ψ, verbosity=verbosity, threshold=threshold, weighted_fluctuation=weighted_fluctuation) # Estimation results after TMLE IC, Ψ̂, ICᵢ, Ψ̂ᵢ = gradient_and_estimates(Ψ, fluctuation_mach, threshold=threshold, weighted_fluctuation=weighted_fluctuation) - tmle_result = ALEstimate(Ψ̂, IC) - one_step_result = ALEstimate(Ψ̂ᵢ + mean(ICᵢ), ICᵢ) + tmle_result = TMLEstimate(Ψ̂, IC) + one_step_result = OSEstimate(Ψ̂ᵢ + mean(ICᵢ), ICᵢ) verbosity >= 1 && @info "Done." TMLEResult(Ψ, tmle_result, one_step_result, Ψ̂ᵢ), fluctuation_mach diff --git a/test/composition.jl b/test/composition.jl index 2f5bf3f8..4f2cf571 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -25,8 +25,8 @@ end @testset "Test cov" begin n = 10 X = rand(n, 2) - ER₁ = TMLE.ALEstimate(1., X[:, 1]) - ER₂ = TMLE.ALEstimate(0., X[:, 2]) + ER₁ = TMLE.TMLEstimate(1., X[:, 1]) + ER₂ = TMLE.TMLEstimate(0., X[:, 2]) Σ = cov(ER₁, ER₂) @test size(Σ) == (2, 2) @test Σ == cov(X) diff --git a/test/estimands.jl b/test/estimands.jl index 99685a6e..847b13d7 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -233,7 +233,8 @@ end @test isconcretetype(ATE) @test isconcretetype(IATE) @test isconcretetype(CM) - @test isconcretetype(TMLE.ALEstimate{Float64}) + @test isconcretetype(TMLE.TMLEstimate{Float64}) + @test isconcretetype(TMLE.OSEstimate{Float64}) @test isconcretetype(TMLE.TMLEResult{ATE, Float64}) @test isconcretetype(TMLE.TMLEResult{IATE, Float64}) @test isconcretetype(TMLE.TMLEResult{CM, Float64}) From 3da0049e260c377bf4feb22a031569efb6781a59 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 9 Aug 2023 14:17:32 +0100 Subject: [PATCH 029/151] add more docs --- Project.toml | 11 + docs/make.jl | 23 ++- docs/src/index.md | 121 +++-------- docs/src/resources.md | 18 ++ docs/src/user_guide.md | 324 ------------------------------ docs/src/user_guide/adjustment.md | 12 ++ docs/src/user_guide/estimands.md | 106 ++++++++++ docs/src/user_guide/estimation.md | 158 +++++++++++++++ docs/src/user_guide/scm.md | 90 +++++++++ docs/src/walk_through.md | 151 ++++++++++++++ src/TMLE.jl | 11 +- src/estimate.jl | 3 +- src/scm.jl | 4 +- src/treatment_transformer.jl | 1 + 14 files changed, 597 insertions(+), 436 deletions(-) create mode 100644 docs/src/resources.md delete mode 100644 docs/src/user_guide.md create mode 100644 docs/src/user_guide/adjustment.md create mode 100644 docs/src/user_guide/estimands.md create mode 100644 docs/src/user_guide/estimation.md create mode 100644 docs/src/user_guide/scm.md create mode 100644 docs/src/walk_through.md diff --git a/Project.toml b/Project.toml index bfed3c91..41e5c4d8 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,14 @@ uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" authors = ["Olivier Labayle"] version = "0.11.4" +[weakdeps] +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" + +[extensions] +GraphsTMLEExt = ["Graphs", "CairoMakie", "GraphMakie"] + [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -40,4 +48,7 @@ TableOperations = "1.2" Tables = "1.6" YAML = "0.4" Zygote = "0.6" +Graphs = "1.8.0" +GraphMakie = "0.5.5" +CairoMakie = "0.10.7" julia = "1.6, 1.7, 1" diff --git a/docs/make.jl b/docs/make.jl index bb3c81bf..2f0edc14 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -4,13 +4,13 @@ using Literate DocMeta.setdocmeta!(TMLE, :DocTestSetup, :(using TMLE); recursive=true) -# Generate Literate markdown pages -@info "Building Literate pages..." -examples_dir = joinpath(@__DIR__, "../examples") -build_examples_dir = joinpath(@__DIR__, "src", "examples/") -for file in readdir(examples_dir) - Literate.markdown(joinpath(examples_dir, file), build_examples_dir;documenter=true) -end +## Generate Literate markdown pages +# @info "Building Literate pages..." +# examples_dir = joinpath(@__DIR__, "../examples") +# build_examples_dir = joinpath(@__DIR__, "src", "examples/") +# for file in readdir(examples_dir) +# Literate.markdown(joinpath(examples_dir, file), build_examples_dir;documenter=true) +# end @info "Running makedocs..." makedocs(; @@ -25,12 +25,15 @@ makedocs(; ), pages=[ "Home" => "index.md", - "User Guide" => "user_guide.md", + "Walk Through" => "walk_through.md", + "User Guide" => [joinpath("user_guide", f) for f in + ("mathematical_setting.md", "scm.md", "estimands.md", "estimation.md", "adjustment.md")], "Examples" => [ #joinpath("examples", "introduction_to_targeted_learning.md"), - joinpath("examples", "super_learning.md"), - joinpath("examples", "double_robustness.md") + # joinpath("examples", "super_learning.md"), + # joinpath("examples", "double_robustness.md") ], + "Resources" => "resources.md", "API Reference" => "api.md" ], ) diff --git a/docs/src/index.md b/docs/src/index.md index 1074b245..35a456d2 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -6,73 +6,7 @@ CurrentModule = TMLE ## Overview -TMLE.jl is a Julia implementation of the Targeted Minimum Loss-Based Estimation ([TMLE](https://link.springer.com/book/10.1007/978-1-4419-9782-1)) framework. If you are interested in efficient and unbiased estimation of some causal effect, you are in the right place. Since TMLE uses machine-learning methods to estimate nuisance estimands, the present package is based upon [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/). This means that any model respecting the MLJ interface can be used for the estimation of nuisance estimands. - -## Mathematical setting - -Even though the TMLE framework is not restricted to causal inference, the scope of the present package is currently restricted to so called "causal effects" and the following causal graph is assumed throughout: - -```@raw html -
-Causal Model -
-``` - -whith the following general interpration of variables: - -- Y: The Target -- T: The Treatment -- W: Confounders -- C: Extra Covariates - -This graph encodes a factorization of the joint probability distribution $P_0$ that we assume generated the observed data: - -```math -P_0(Y, T, W, C) = P_0(Y|T, W, C) \cdot P_0(T|W) \cdot P_0(W) \cdot P_0(C) -``` - -Usually, we don't have much information about $P_0$ and don't want to make unrealistic assumptions, thus we will simply state that $P_0 \in \mathcal{M}$ where $\mathcal{M}$ is the space of all probability distributions. It turns out that most target quantities of interest in causal inference can be expressed as real-valued functions of $P_0$, denoted by: $\Psi:\mathcal{M} \rightarrow \Re$. TMLE works by finding a suitable estimator $\hat{P}_n$ of $P_0$ and then simply substituting it into $\Psi$. Fortunately, it is often non necessary to come up with an estimator of the joint distribution $P_0$, instead only subparts of it are required. Those are called nuisance estimands because they are not of direct interest and for our purposes, only two nuisance estimands are necessary: - -```math -Q_0(t, w, c) = \mathbb{E}[Y|T=t, W=w, C=c] -``` - -and - -```math -G_0(t, w) = P(T=t|W=w) -``` - -Specifying which machine-learning method to use for each of those nuisance estimand is the only ingredient you will need to run a targeted estimation for any of the following quantities. Those quantities have a causal interpretation under suitable assumptions that we do not discuss here. - -### The Conditional Mean - -This is simply the expected value of the outcome $Y$ when the treatment is set to $t$. - -```math -CM_t(P) = \mathbb{E}[\mathbb{E}[Y|T=t, W]] -``` - -### The Average Treatment Effect (ATE) - -Probably the most famous quantity in causal inference. Under suitable assumptions, tt represents the additive effect of changing the treatment value from $t_1$ to $t_2$. - -```math -ATE_{t_1 \rightarrow t_2}(P) = \mathbb{E}[\mathbb{E}[Y|T=t_2, W]] - \mathbb{E}[\mathbb{E}[Y|T=t_1, W]] -``` - -### Interaction Average Treatment Effect (IATE) - -If you are interested in the interaction effect of multiple treatments, the IATE is for you. An example formula is displayed below for two interacting treatments whose values are both changed from $0$ to $1$: - -```math -IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}[\mathbb{E}[Y|T_1=1, T_2=1, W]] - \mathbb{E}[\mathbb{E}[Y|T_1=1, T_2=0, W]] \\ -- \mathbb{E}[\mathbb{E}[Y|T_1=0, T_2=1, W]] + \mathbb{E}[\mathbb{E}[Y|T_1=0, T_2=0, W]] -``` - -### Any function of the previous Estimands - -As a result of Julia's automatic differentiation facilities, given a set of already estimated estimands $(\Psi_1, ..., \Psi_k)$, we can automatically compute an estimator for $f(\Psi_1, ..., \Psi_k)$. +TMLE.jl is a Julia implementation of the Targeted Minimum Loss-Based Estimation ([TMLE](https://link.springer.com/book/10.1007/978-1-4419-9782-1)) framework. If you are interested in efficient and unbiased estimation of causal effects, you are in the right place. Since TMLE uses machine-learning methods to estimate nuisance estimands, the present package is based upon [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/). ## Installation @@ -86,11 +20,11 @@ Pkg> add TMLE To run an estimation procedure, we need 3 ingredients: -- A dataset -- A estimand of interest -- A nuisance estimands specification +1. A Structural Causal Model that describes the relationship between the variables. +2. An estimand of interest +3. A dataset -For instance, assume the following simple data generating process: +For illustration, assume we know the actual data generating process is as follows: ```math \begin{aligned} @@ -100,33 +34,21 @@ Y &\sim \mathcal{Normal}(1 + 3 \cdot T - T \cdot W, 0.01) \end{aligned} ``` -which can be simulated in Julia by: +Let's define the associated Structural Causal Model: ```@example quick-start -using Distributions -using StableRNGs -using Random -using CategoricalArrays -using MLJLinearModels using TMLE -using LogExpFunctions -rng = StableRNG(123) -n = 100 -W = rand(rng, Uniform(), n) -T = rand(rng, Uniform(), n) .< logistic.(1 .- 2W) -Y = 1 .+ 3T .- T.*W .+ rand(rng, Normal(0, 0.01), n) -dataset = (Y=Y, T=categorical(T), W=W) -nothing # hide +scm = StaticConfoundedModel(:Y, :T, :W) ``` -And say we are interested in the $ATE_{0 \rightarrow 1}(P_0)$: +Now, say we are interested in the Average Treatment Effect of the treatment ``T`` on the outcome ``Y``: $ATE_{T:0 \rightarrow 1}(Y)$: ```@example quick-start Ψ = ATE( + scm, outcome = :Y, treatment = (T=(case=true, control = false),), - confounders = [:W] ) nothing # hide ``` @@ -137,23 +59,34 @@ Note that in this example the ATE can be computed exactly and is given by: ATE_{0 \rightarrow 1}(P_0) = \mathbb{E}[1 + 3 - W] - \mathbb{E}[1] = 3 - \mathbb{E}[W] = 2.5 ``` -We next need to define the strategy for learning the nuisance estimands $Q$ and $G$, here we keep things simple and simply use generalized linear models: +Because we know the data generating process, we can simulate some data accordingly: ```@example quick-start -η_spec = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0) -) +using Distributions +using StableRNGs +using Random +using CategoricalArrays +using MLJLinearModels +using LogExpFunctions + +rng = StableRNG(123) +n = 100 +W = rand(rng, Uniform(), n) +T = rand(rng, Uniform(), n) .< logistic.(1 .- 2W) +Y = 1 .+ 3T .- T.*W .+ rand(rng, Normal(0, 0.01), n) +dataset = (Y=Y, T=categorical(T), W=W) nothing # hide ``` -We are now ready to run the TMLE procedure and look the associated confidence interval: +Running the `tmle` will produce two asymptotically linear estimators: the TMLE and the One Step Estimator. For each we can look at the associated estimate, confidence interval and p-value: ```@example quick-start -result, _ = tmle(Ψ, η_spec, dataset) +result, _ = tmle(Ψ, dataset) result ``` +and be comforted to see that our estimators covers the ground truth! 🥳 + ```@example quick-start using Test # hide @test pvalue(OneSampleTTest(result.tmle, 2.5)) > 0.05 # hide diff --git a/docs/src/resources.md b/docs/src/resources.md new file mode 100644 index 00000000..b18c30dd --- /dev/null +++ b/docs/src/resources.md @@ -0,0 +1,18 @@ +# Resources + +Targeted Learning is a difficult topic, while it is not strictly necessary to understand the details to use this package it can certainly help. Here is an incomplete list of external ressources that I found useful in my personnal search for enlightenment. + +## Websites + +These are two very clear introductions to causal inference and semi-parametric estimation: + +- [Introduction to Modern Causal Inference](https://alejandroschuler.github.io/mci/) (Alejandro Schuler, Mark J. van der Laan). +- [A Ride in Targeted Learning Territory](https://achambaz.github.io/tlride/) (David Benkeser, Antoine Chambaz). + +## Text Books + +- [Targeted Learning](https://link.springer.com/book/10.1007/978-1-4419-9782-1) (Mark J. van der Laan, Sherri Rose). + +## Journal articles + +- [Semiparametric doubly robust targeted double machine learning: a review](https://arxiv.org/abs/2203.06469) (Edward H. Kennedy). diff --git a/docs/src/user_guide.md b/docs/src/user_guide.md deleted file mode 100644 index 7e73984d..00000000 --- a/docs/src/user_guide.md +++ /dev/null @@ -1,324 +0,0 @@ -# User Guide - -```@meta -CurrentModule = TMLE -``` - -## The Dataset - -TMLE.jl should be compatible with any dataset respecting the [Tables.jl](https://tables.juliadata.org/stable/) interface, that is, a structure like a `NamedTuple` or a `DataFrame` from [DataFrames.jl](https://dataframes.juliadata.org/stable/) should work. In the remainder of this section, we will be working with the same dataset and see that we can ask very many questions (Estimands) from it. - -```@example user-guide -using Random -using Distributions -using DataFrames -using StableRNGs -using CategoricalArrays -using TMLE -using LogExpFunctions - -function make_dataset(;n=1000) - rng = StableRNG(123) - # Confounders - W₁ = rand(rng, Uniform(), n) - W₂ = rand(rng, Uniform(), n) - # Covariates - C₁ = rand(rng, Uniform(), n) - # Treatment | Confounders - T₁ = rand(rng, Uniform(), n) .< logistic.(0.5sin.(W₁) .- 1.5W₂) - T₂ = rand(rng, Uniform(), n) .< logistic.(-3W₁ - 1.5W₂) - # Target | Confounders, Covariates, Treatments - Y = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .- 2T₂.*T₁.*W₂ .+ rand(rng, Normal(0, 0.1), n) - return DataFrame( - W₁ = W₁, - W₂ = W₂, - C₁ = C₁, - T₁ = categorical(T₁), - T₂ = categorical(T₂), - Y = Y - ) -end -dataset = make_dataset() -nothing # hide -``` - -!!! note "Note on Treatment variables" - It should be noted that the treatment variables **must** be [categorical](https://categoricalarrays.juliadata.org/stable/). Since the treatment is also used as an input to the ``Q_0`` learner, a `OneHotEncoder` is used by default (see [The Nuisance Estimands](@ref) section). If a numerical representation is more appropriate (ordinal variables), use the keyword `ordered=true` when constructing a `categorical` vector, the `OneHotEncoder` will ignore those variables and their floating point representation will be used. - -## The Nuisance Estimands - -As described in the [Mathematical setting](@ref) section, we need to provide an estimation strategy for both $Q_0$ and $G_0$. For illustration purposes, we here consider a simple strategy where both models are assumed to be generalized linear models. However this is not the recommended practice since there is little chance those functions are actually linear, and theoretical guarantees associated with TMLE may fail to hold. We recommend instead the use of Super Learning which is exemplified in [The benefits of Super Learning](@ref). - -```@example user-guide -using MLJLinearModels - -η_spec = NuisanceSpec( - LinearRegressor(), # Q model - LogisticClassifier(lambda=0) # G model -) -nothing # hide -``` - -!!! note "Practical note on $Q_0$ and $G_0$" - The models chosen for the nuisance estimands should be adapted to the outcome they target: - - ``Q_0 = \mathbf{E}_0[Y|T=t, W=w, C=c]``, $Y$ can be either continuous or categorical, in our example it is continuous and a `LinearRegressor` is a correct choice. - - ``G_0 = P_0(T|W)``, $T$ are always categorical variables. If T is a single treatment with only 2 levels, a logistic regression will work, if T has more than two levels a multinomial regression for instance would be suitable. If there are more than 2 treatment variables (with potentially more than 2 levels), then, the joint distribution is learnt and a multinomial regression would also work. In any case, the `LogisticClassifier` from [MLJLinearModels](https://juliaai.github.io/MLJLinearModels.jl/stable/) is a suitable choice. - For more information on available models and their uses, we refer to the [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) documentation - -The `NuisanceSpec` struct also holds a specification for the `OneHotEncoder` necessary for the encoding of treatment variables and a generalized linear model for the fluctuation model. Unless you know what you are doing, there is little chance you need to modify those. - -## Estimands - -### The Conditional mean - -We are now ready to move to the definition of the estimands of interest. The most basic type of estimand is the conditional mean of the outcome given the treatment: - -```math -CM_t(P) = \mathbb{E}[\mathbb{E}[Y|T=t, W]] -``` - -The treatment does not have to be restricted to a single variable, we can define for instance $CM_{T_1=1, T_2=1}$: - -```@example user-guide -Ψ = CM( - outcome = :Y, - treatment = (T₁=true, T₂=true), - confounders = [:W₁, :W₂] -) -nothing # hide -``` - -In this case, we can compute the exact value of the estimand: - -```math -CM_{T_1=1, T_2=1} = 1 + 2\mathbb{E}[W₁] + 3\mathbb{E}[W₂] - 4\mathbb{E}[C₁] - 2\mathbb{E}[W₂] = 0.5 -``` - -Running a targeted estimation procedure should then be as simple as: - -```@example user-guide -cm_result₁₁, _ = tmle(Ψ, η_spec, dataset, verbosity=0) -nothing # hide -``` - -For now, let's ignore the `_` output and focus on the `TMLEResult`. It contains estimates corresponding to 3 distinct estimators of $CM_{T_1=1, T_2=0}$. - -- The `tmle` field contains the Targeted Minimum Loss-based estimate and associated influence curve. -- The `onestep` field contains the One-Step estimate and associated influence curve. -- The `initial` field contains the initial or naive estimate. - -For both the `tmle` and `onestep`, we can have a look at the value and variance of the estimator, since the estimator is asymptotically normal, a 95% confidence interval can be rougly constructed via: - -```@example user-guide -Ψ̂ = TMLE.estimate(cm_result₁₁.tmle) -σ² = var(cm_result₁₁.tmle) -Ψ̂ - 1.96√σ² <= 0.5 <= Ψ̂ + 1.96√σ² -``` - -In fact, we can easily be more rigorous here and perform a standard T test: - -```@example user-guide -OneSampleTTest(cm_result₁₁.tmle) -``` - -### The Average Treatment Effect - -Let's now turn our attention to the Average Treatment Effect: - -```math -ATE_{t_1 \rightarrow t_2}(P) = \mathbb{E}[\mathbb{E}[Y|T=t_2, W]] - \mathbb{E}[\mathbb{E}[Y|T=t_1, W]] -``` - -Again, from our dataset, there are many ATEs we may be interested in, let's assume we are interested in ``ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1}``. Since we know the generating process, this can be computed exactly and we have: - -```math -\begin{aligned} -ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} &= (1 + 2\mathbb{E}[W₁] + 3\mathbb{E}[W₂] - 4\mathbb{E}[C₁] - 2\mathbb{E}[W₂]) - (1 + 2\mathbb{E}[W₁] + 3\mathbb{E}[W₂]) \\ - &= -3 -\end{aligned} -``` - -Let's see what the TMLE tells us: - -```@example user-guide -Ψ = ATE( - outcome = :Y, - treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders = [:W₁, :W₂] -) - -ate_result, _ = tmle(Ψ, η_spec, dataset, verbosity=0) - -OneSampleTTest(ate_result.tmle) -``` - -As expected. - -### The Interaction Average Treatment Effect - -Finally let us look at the most interesting case of interactions, we compute here the ``IATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1}`` since this is the highest order in our data generating process. That being said, you could go after any higher-order (3, 4, ...) interaction if you wanted. However you will inevitably decrease power and encounter positivity violations as you climb the interaction-order ladder. - -Let us provide the ground truth for this pairwise interaction, you can check that: - -```math -\begin{aligned} -IATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} &= 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .- 2T₂.*T₁.*W₂ \\ - &= - 1 -\end{aligned} -``` - -and run: - -```@example user-guide -Ψ = IATE( - outcome = :Y, - treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders = [:W₁, :W₂] -) - -iate_result, _ = tmle(Ψ, η_spec, dataset, verbosity=0) - -OneSampleTTest(iate_result.tmle) -``` - -### Composing Estimands - -By leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation facilities, we can actually compute any new estimand estimate from a set of already estimated estimands. By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system. - -For instance, by definition of the ATE, we should be able to retrieve ``ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1}`` by composing ``CM_{T_1=1, T_2=1} - CM_{T_1=0, T_2=0}``. We already have almost all of the pieces, we just need an estimate for ``CM_{T_1=0, T_2=0}``, let's get it. - -```@example user-guide -Ψ = CM( - outcome = :Y, - treatment = (T₁=false, T₂=false), - confounders = [:W₁, :W₂] -) -cm_result₀₀, _ = tmle(Ψ, η_spec, dataset, verbosity=0) -nothing # hide -``` - -```@example user-guide -composed_ate_result = compose(-, cm_result₁₁.tmle, cm_result₀₀.tmle) -nothing # hide -``` - -We can compare the estimate value, which is simply obtained by applying the function to the arguments: - -```@example user-guide -TMLE.estimate(composed_ate_result), TMLE.estimate(ate_result.tmle) -``` - -and the variance: - -```@example user-guide -var(composed_ate_result), var(ate_result.tmle) -``` - -### Reading Estimands from YAML files - -It may be useful to read and write estimands to text files. We provide this functionality using the YAML format and the `estimands_from_yaml` and `estimands_to_yaml` functions. Since YAML is quite sensitive to identation, it is not recommended to write those files by hand. As an illustration, an example of such a file is given below: - -```yaml -Estimands: - - outcome: Y1 - treatment: (T2 = (case = 1, control = 0), T1 = (case = 1, control = 0)) - confounders: - - W1 - - W2 - covariates: - - C1 - type: IATE - - outcome: Y3 - treatment: (T3 = (case = 1, control = 0), T1 = (case = "AC", control = "CC")) - confounders: - - W1 - covariates: [] - type: ATE - - outcome: Y3 - treatment: (T3 = "AC", T1 = "CC") - confounders: - - W1 - covariates: [] - type: CM -``` - -## Using the cache - -Oftentimes, we are interested in multiple estimands, or would like to investigate how our estimator is affected by changes in the nuisance estimands specification. In many cases, as long as the dataset under study is the same, it is possible to save some computational time by caching the previously learnt nuisance estimands. We describe below how TMLE.jl proposes to do that in some common scenarios. For that purpose let us add a new outcome variable (which is simply random noise) to our dataset: - -```@example user-guide -dataset.Ynew = rand(1000) -nothing # hide -``` - -### Scenario 1: Changing the treatment values - -Let us say we are interested in two ATE estimands: ``ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1}`` and ``ATE_{T_1=1 \rightarrow 0, T_2=0 \rightarrow 1}`` (Notice how the setting for $T_1$ has changed). - -Let us start afresh an compute the first ATE: - -```@example user-guide -Ψ = ATE( - outcome = :Y, - treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders = [:W₁, :W₂] -) - -ate_result₁, cache = tmle(Ψ, η_spec, dataset) -nothing # hide -``` - -Notice the logs are informing you of all the nuisance estimands that are being fitted. - -Let us now investigate the second ATE by using the cache: - -```@example user-guide -Ψ = ATE( - outcome = :Y, - treatment = (T₁=(case=false, control=true), T₂=(case=true, control=false)), - confounders = [:W₁, :W₂] -) - -ate_result₂, cache = tmle!(cache, Ψ) -nothing # hide -``` - -You should see that the logs are actually now telling you which nuisance estimands have been reused, i.e. all of them, only the targeting step needs to be done! This is because we already had nuisance estimators that matched our outcome estimand. - -### Scenario 2: Changing the outcome - -Let us now imagine that we are interested in another outcome: $Ynew$, we can say so by defining a new estimand and running the TMLE procedure using the cache: - -```@example user-guide -Ψ = ATE( - outcome = :Ynew, - treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false)), - confounders = [:W₁, :W₂] -) - -ate_result₃, cache = tmle!(cache, Ψ) -nothing # hide -``` - -As you can see, only $Q$ has been updated because the existing cached $G$ already matches our target estimand and cane be reused. - -### Scenario 3: Changing the nuisance estimands specification - -Another common situation is to try a new model for a given nuisance estimand (or both). Here we can try a new regularization estimand for our logistic regression: - -```@example user-guide -η_spec = NuisanceSpec( - LinearRegressor(), # Q model - LogisticClassifier(lambda=0.001) # Updated G model -) - -ate_result₄, cache = tmle!(cache, η_spec) -nothing # hide -``` - -Since we have only updated $G$'s specification, only this model is fitted again. - -### General behaviour - -Any change to either the `Estimand` of interest or the `NuisanceSpec` structures will trigger an update of the cache. If you have a predefined -list of estimands to estimate, you can optimize the estimation order of those estimands by calling `optimize_ordering!` or `optimize_ordering`. This will order the estimands in order to save the number of required nuisance estimands fits. diff --git a/docs/src/user_guide/adjustment.md b/docs/src/user_guide/adjustment.md new file mode 100644 index 00000000..5bf7af90 --- /dev/null +++ b/docs/src/user_guide/adjustment.md @@ -0,0 +1,12 @@ +# Adjustment Methods + +In a `SCM`, each variable is determined by a set of parents and a statistical model describing the functional relationship between them. However, for the estimation of Causal Estimands the fitted models may not exactly correspond to the variable's equation in the `SCM`. Adjustment methods tell the estimation procedure which input variables should be incorporated in the statistical model fits. + +## Backdoor Adjustment + +At the moment we provide a single adjustment method, namely the Backdoor adjustment method. The adjustment set consists of all the treatment variable's parents. Additional covariates used to fit the outcome model can be provided via `outcome_extra`. + +```@example +BackdoorAdjustment(;outcome_extra=[:C]) +nothing +``` diff --git a/docs/src/user_guide/estimands.md b/docs/src/user_guide/estimands.md new file mode 100644 index 00000000..a4d319da --- /dev/null +++ b/docs/src/user_guide/estimands.md @@ -0,0 +1,106 @@ +# Estimands + +Most causal questions can be translated into a causal estimand. Usually either an interventional or counterfactual quantity. What would have been the outcome if I had set this variable to this value? When identified, this causal estimand translates to a statistical estimand which can be estimated from data. For us, an estimand will be a functional, that is a function that takes as input a probability distribution, and outputs a real number or vector of real numbers. + +Mathematically speaking, denoting the estimand by ``\Psi``, the set of all probability distributions by ``\mathcal{M}``: + +```math +\Psi: \mathcal{M} \rightarrow \mathbb{R}^p +``` + +At the moment, most of the work in this package has been focused on estimands that are composite functions of the interventional conditional mean which is easily identified via backdoor adjustment and for which the efficient influence function is well known. + +In what follows, ``P`` is a probability distribution generating an outcome ``Y``, a random vector of "treatment" variables ``\textbf{T}`` and a random vector of "confounding" variables ``\textbf{W}``. + +## The Interventional Conditional Mean (CM) + +- Causal Question: + +What would be the mean of ``Y`` in the population if we intervened on ``\textbf{T}`` and set it to ``\textbf{t}``? + +- Causal Estimand: + +```math +CM_{\textbf{t}}(P) = \mathbb{E}[Y|do(\textbf{T}=\textbf{t})] +``` + +- Statistical Estimand (via backdoor adjustment): + +```math +CM_{\textbf{t}}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|do(\textbf{T}=\textbf{t}), \textbf{W}]] +``` + +- TMLE.jl Example + +For a Structural Causal Model `scm`, an outcome `Y` and two treatments `T₁` and `T₂` taking values `t₁` and `t₂` respectively: + +```julia +Ψ = CM(scm, outcome=:Y, treatment=(T₁=t₁, T₂=t₂)) +``` + +## The Average Treatment Effect + +- Causal Question: + +What is the average difference in treatment effect on ``Y`` when the two treatment levels are set to ``\textbf{t}_1`` and ``\textbf{t}_2`` respectively? + +- Causal Estimand: + +```math +ATE_{\textbf{t}_1 \rightarrow \textbf{t}_2}(P) = \mathbb{E}[Y|do(\textbf{T}=\textbf{t}_2)] - \mathbb{E}[Y|do(\textbf{T}=\textbf{t}_1)] +``` + +- Statistical Estimand (via backdoor adjustment): + +```math +\begin{aligned} +ATE_{\textbf{t}_1 \rightarrow \textbf{t}_2}(P) &= CM_{\textbf{t}_2}(P) - CM_{\textbf{t}_1}(P) \\ +&= \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|\textbf{T}=\textbf{t}_2, \textbf{W}] - \mathbb{E}[Y|\textbf{T}=\textbf{t}_1, \textbf{W}]] +\end{aligned} +``` + +- TMLE.jl Example + +For a Structural Causal Model `scm`, an outcome `Y` and two treatments differences `T₁`:`t₁₁ → t₁₂` and `T₂`:`t₂₁ → t₂₂`: + +```julia +Ψ = ATE(scm, outcome=:Y, treatment=(T₁=(case=t₁₂, control=t₁₁), T₂=(case=t₂₂, control=t₂₁))) +``` + +Note that all treatments are not required to change, for instance the following where `T₁` is held fixed at `t₁₁` is also a valid `ATE`: + +```julia +Ψ = ATE(scm, outcome=:Y, treatment=(T₁=(case=t₁₁, control=t₁₁), T₂=(case=t₂₂, control=t₂₁))) +``` + +## The Interaction Average Treatment Effect + +- Causal Question: + +Interactions can be defined up to any order but we restrict the interpretation to two variables. Is the Total Average Treatment Effect of ``T₁`` and ``T₂`` different from the sum of their respective marginal Average Treatment Effects? Is there a synergistic effect between ``T₁`` and ``T₂`` on ``Y``. + +For a general higher-order definition, please refer to [Higher-order interactions in statistical physics and machine learning: A model-independent solution to the inverse problem at equilibrium](https://arxiv.org/abs/2006.06010). + +- Causal Estimand: + +For two points interaction with both treatment and control levels ``0`` and ``1`` for ease of notation: + +```math +IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}[Y|do(T_1=1, T_2=1)] - \mathbb{E}[Y|do(T_1=1, T_2=0)] \\ +- \mathbb{E}[Y|do(T_1=0, T_2=1)] + \mathbb{E}[Y|do(T_1=0, T_2=0)] +``` + +- Statistical Estimand (via backdoor adjustment): + +```math +IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|T_1=1, T_2=1, \textbf{W}]] - \mathbb{E}[Y|T_1=1, T_2=0, \textbf{W}] \\ +- \mathbb{E}[Y|T_1=0, T_2=1, \textbf{W}] + \mathbb{E}[Y|T_1=0, T_2=0, \textbf{W}]] +``` + +- TMLE.jl Example + +For a Structural Causal Model `scm`, an outcome `Y` and two treatments differences `T₁`:`t₁₁ → t₁₂` and `T₂`:`t₂₁ → t₂₂`: + +```julia +Ψ = IATE(scm, outcome=:Y, treatment=(T₁=(case=t₁₂, control=t₁₁), T₂=(case=t₂₂, control=t₂₁))) +``` diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md new file mode 100644 index 00000000..e2e4d4cc --- /dev/null +++ b/docs/src/user_guide/estimation.md @@ -0,0 +1,158 @@ +# Estimation + +## Estimating a single Estimand + +```@setup estimation +using Random +using Distributions +using DataFrames +using StableRNGs +using CategoricalArrays +using TMLE +using LogExpFunctions + +function make_dataset(;n=1000) + rng = StableRNG(123) + # Confounders + W₁₁= rand(rng, Uniform(), n) + W₁₂ = rand(rng, Uniform(), n) + W₂₁= rand(rng, Uniform(), n) + W₂₂ = rand(rng, Uniform(), n) + # Covariates + C = rand(rng, Uniform(), n) + # Treatment | Confounders + T₁ = rand(rng, Uniform(), n) .< logistic.(0.5sin.(W₁₁) .- 1.5W₁₂) + T₂ = rand(rng, Uniform(), n) .< logistic.(-3W₂₁ - 1.5W₂₂) + # Target | Confounders, Covariates, Treatments + Y = 1 .+ 2W₂₁ .+ 3W₂₂ .+ W₁₁ .- 4C.*T₁ .- 2T₂.*T₁.*W₁₂ .+ rand(rng, Normal(0, 0.1), n) + return DataFrame( + W₁₁ = W₁₁, + W₁₂ = W₁₂, + W₂₁ = W₂₁, + W₂₂ = W₂₂, + C = C, + T₁ = categorical(T₁), + T₂ = categorical(T₂), + Y = Y + ) +end +dataset = make_dataset() +scm = SCM( + SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], with_encoder(LinearRegressor())), + SE(:T₁, [:W₁₁, :W₁₂], LogisticClassifier()), + SE(:T₂, [:W₂₁, :W₂₂], LogisticClassifier()), +) +``` + +Once a `SCM` and an estimand have been defined, we can proceed with Targeted Estimation. This is done via the `tmle` function. Drawing from the example dataset and `SCM` from the Walk Through section, we can estimate the ATE for `T₁`. + +```@example estimation +Ψ₁ = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false),)) +result, fluctuation_mach = tmle(Ψ₁, dataset; + adjustment_method=BackdoorAdjustment([:C]), + verbosity=1, + force=false, + threshold=1e-8, + weighted_fluctuation=false +) +``` + +We see that both models corresponding to variables `Y` and `T₁` were fitted in the process but that the model for `T₂` was not because it was not necessary to estimate this estimand. + +The `fluctuation_mach` corresponds to the fitted machine that was used to fluctuate the initial fit. For instance, we can see what is the value of ``\epsilon`` corresponding to the clever covariate. + +```@example estimation +ϵ = fitted_params(fluctuation_mach).coef[1] +``` + +The `result` corresponds to the estimation result and contains 3 main elements: + +- The `TMLEEstimate` than can be accessed via: `tmle(result)`. +- The `OSEstimate` than can be accessed via: `ose(result)`. +- The naive initial estimate. + +Since both the TMLE and OSE are asymptotically linear estimators, standard T tests from [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) can be performed for each of them. + +```@example estimation +tmle_test_result = OneSampleTTest(tmle(result)) +``` + +We could now get an interest in the Average Treatment Effect of `T₂`: + +```@example estimation +Ψ₂ = ATE(scm, outcome=:Y, treatment=(T₂=(case=true, control=false),)) +result, fluctuation_mach = tmle(Ψ₂, dataset; + adjustment_method=BackdoorAdjustment([:C]), + verbosity=1, + force=false, + threshold=1e-8, + weighted_fluctuation=false +) +``` + +The model for `T₂` was fitted in the process but so was the model for `Y` 🤔. This is because the `BackdoorAdjustment` method determined that the set of inputs for `Y` were different in both cases. + +## Reusing the SCM + +Let's now see how the models can be reused with a new estimand, say the Total Average Treatment Effecto of both `T₁` and `T₂`. + +```@example estimation +Ψ₃ = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false))) +result, fluctuation_mach = tmle(Ψ₃, dataset; + adjustment_method=BackdoorAdjustment([:C]), + verbosity=1, + force=false, + threshold=1e-8, + weighted_fluctuation=false +) +``` + +This time only the statistical model for `Y` is fitted again while reusing the models for `T₁` and `T₂`. Finally, let's see what happens if we estimate the `IATE` between `T₁` and `T₂`. + +```@example estimation +Ψ₄ = IATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false))) +result, fluctuation_mach = tmle(Ψ₄, dataset; + adjustment_method=BackdoorAdjustment([:C]), + verbosity=1, + force=false, + threshold=1e-8, + weighted_fluctuation=false +) +``` + +All statistical models have been reused 😊! + +## Ordering the estimands + +Given a vector of estimands, a clever ordering can be obtained via the `optimize_ordering/optimize_ordering!` functions. + +```@example estimation +optimize_ordering([Ψ₃, Ψ₁, Ψ₂, Ψ₄]) == [Ψ₁, Ψ₃, Ψ₄, Ψ₂] +``` + +## Composing Estimands + +By leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation facilities, we can estimate any estimand which is a function of already estimated estimands. By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system. + +For instance, by definition of the ATE, we should be able to retrieve ``ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1}`` by composing ``CM_{T_1=1, T_2=1} - CM_{T_1=0, T_2=0}``. We already have almost all of the pieces, we just need an estimate for ``CM_{T_1=0, T_2=0}``, let's get it. + +```@example estimation +Ψ = CM( + outcome = :Y, + treatment = (T₁=false, T₂=false), + confounders = [:W₁, :W₂] +) +cm_result₀₀, _ = tmle(Ψ, η_spec, dataset, verbosity=0) +nothing # hide +``` + +```@example estimation +composed_ate_result = compose(-, cm_result₁₁.tmle, cm_result₀₀.tmle) +nothing # hide +``` + +## Weighted Fluctuation + +It has been reported that, in settings close to positivity violation (some treatments' values are very rare) TMLE may be unstable. This has been shown to be stabilized by fitting a weighted fluctuation model instead and by slightly modifying the clever covariate to keep things mathematically sound. + +This is implemented in TMLE.jl and can be turned on by selecting `weighted_fluctuation=true` in the `tmle` function. diff --git a/docs/src/user_guide/scm.md b/docs/src/user_guide/scm.md new file mode 100644 index 00000000..66d95057 --- /dev/null +++ b/docs/src/user_guide/scm.md @@ -0,0 +1,90 @@ +# Structural Causal Models + +In TMLE.jl, everything starts from the definition of a Structural Causal Model (`SCM`). A `SCM` in a series of Structural Equations (`SE`) that describe the causal relationships between the random variables under study. The purpose of this package is not to infer the `SCM`, instead we assume it is provided by the user. There are multiple ways one can define a `SCM` that we now describe. + +## Incremental Construction + +All models are wrong? Well maybe not the following: + +```@example scm-incremental +scm = SCM() +``` + +This model does not say anything about the random variables and is thus not really useful. Let's assume that we are interested in an outcome ``Y`` and that this outcome is determined by 8 other random variables. We can add this assumption to the model + +```@example scm-incremental +push!(scm, SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C])) +``` + +At this point, we haven't made any assumption regarding the functional form of the relationship between ``Y`` and its parents. We can add a further assumption by setting a statistical model for ``Y``, suppose we know it is generated from a logistic model, we can make that explicit: + +```@example scm-incremental +setmodel!(scm.Y, LogisticClassifier()) +``` + +--- +ℹ️ **Note on Models** + +- TMLE.jl is based on the main machine-learning framework in Julia: [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/). As such, any model respecting the MLJ interface is a valid model in TMLE.jl. +- In real world scenarios, we usually don't know what is the true statistical model for each variable and want to keep it as large as possible. For this reason it is recommended to use Super-Learning which is implemented in MLJ by the [Stack](https://alan-turing-institute.github.io/MLJ.jl/dev/model_stacking/#Model-Stacking) and comes with theoretical properties. +- In the dataset, treatment variables are represented with [categorical data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/). This means the models that depend on such variables will need to properly deal with them. For this purpose we provide a [TreatmentTransformer](@ref) which can easily be combined with any `model` in a [Pipelining](https://alan-turing-institute.github.io/MLJ.jl/dev/linear_pipelines/) flavour with `with_encoder(model)`. +- The `SCM` has no knowledge of the data and thus cannot verify that the assumed statistical model is compatible with the data. This is done at a later stage. + +--- + +Let's now assume that we have a more complete knowledge of the problem and we also know how `T₁` and `T₂` depend on the rest of the variables in the system. + +```@example scm-incremental +push!(scm, SE(:T₁, [:W₁₁, :W₁₂, :W], model=LogisticClassifier())) +push!(scm, SE(:T₂, [:W₂₁, :W₂₂, :W])) +``` + +## One Step Construction + +Instead of constructing the `SCM` incrementally, one can provide all the specified equations at once: + +```@example scm-one-step +scm = SCM( + SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C], with_encoder(LinearRegressor())), + SE(:T₁, [:W₁₁, :W₁₂, :W], model=LogisticClassifier()), + SE(:T₂, [:W₂₁, :W₂₂, :W]), +) +``` + +Noting that we have used the `with_encoder` function to reflect the fact that we know that `T₁` and `T₂` are categorical variables. + +## Classic Structural Causal Models + +There are many cases where we are interested in estimating the causal effect of a single treatment variable on a single outcome. Because it is typically only necessary to adjust for backdoor variables in order to identify this causal effect, we provide the `StaticConfoundedModel` interface to build such `SCM`: + +```@example static-scm-1 +scm = StaticConfoundedModel( + :Y, :T, [:W₁, :W₂]; + covariates=[:C], + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LogisticClassifier() +) +``` + +The optional `covariates` are variables that influence the outcome but are not confounding the treatment. + +This model can be extended to a plate-model with multiple treatments and multiple outcomes. In this case the set of confounders is assumed to confound all treatments which are in turn assumed to impact all outcomes. This can be defined as: + +```@example static-scm-2 +scm = StaticConfoundedModel( + [:Y₁, :Y₂], [:T₁, :T₂], [:W₁, :W₂]; + covariates=[:C], + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LogisticClassifier() +) +``` + +More classic `SCM` may be added in the future based on needs. + +## Fitting the SCM + +It is usually not necessary to fit an entire `SCM` in order to estimate causal estimands of interest. Instead only some components are required and will be automatically determined (see [Estimation](@ref)). However, if you like, you can fit all the equations for which statistical models have been provided against a dataset: + +```julia +fit!(scm, dataset) +``` diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md new file mode 100644 index 00000000..8686cb40 --- /dev/null +++ b/docs/src/walk_through.md @@ -0,0 +1,151 @@ +# Walk Through + +The goal of this section is to provide a comprehensive (but non-exhaustive) illustration of the estimation process provided in TMLE.jl. For an in-depth explanation, please refer to the User Guide. + +## The Dataset + +TMLE.jl is compatible with any dataset respecting the [Tables.jl](https://tables.juliadata.org/stable/) interface, that is for instance, a `NamedTuple`, a `DataFrame`, an `Arrow.Table` etc... In this section, we will be working with the same dataset all along. + +⚠️ One thing to note is that treatment variables as well as binary outcomes **must** be encoded as `categorical` variables in the dataset (see [MLJ Working with categorical data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/)). + +The dataset is generated as follows: + +```@example user-guide +using Random +using Distributions +using DataFrames +using StableRNGs +using CategoricalArrays +using TMLE +using LogExpFunctions + +function make_dataset(;n=1000) + rng = StableRNG(123) + # Confounders + W₁₁= rand(rng, Uniform(), n) + W₁₂ = rand(rng, Uniform(), n) + W₂₁= rand(rng, Uniform(), n) + W₂₂ = rand(rng, Uniform(), n) + # Covariates + C = rand(rng, Uniform(), n) + # Treatment | Confounders + T₁ = rand(rng, Uniform(), n) .< logistic.(0.5sin.(W₁₁) .- 1.5W₁₂) + T₂ = rand(rng, Uniform(), n) .< logistic.(-3W₂₁ - 1.5W₂₂) + # Target | Confounders, Covariates, Treatments + Y = 1 .+ 2W₂₁ .+ 3W₂₂ .+ W₁₁ .- 4C.*T₁ .- 2T₂.*T₁.*W₁₂ .+ rand(rng, Normal(0, 0.1), n) + return DataFrame( + W₁₁ = W₁₁, + W₁₂ = W₁₂, + W₂₁ = W₂₁, + W₂₂ = W₂₂, + C = C, + T₁ = categorical(T₁), + T₂ = categorical(T₂), + Y = Y + ) +end +dataset = make_dataset() +nothing # hide +``` + +Even though the role of a variable (treatment, outcome, confounder, ...) is relative to the problem setting, this dataset can intuitively be decomposed into: + +- 1 Outcome variable (``Y``). +- 2 Treatment variables ``(T₁, T₂)`` with confounders ``(W₁₁, W₁₂)`` and ``(W₂₁, W₂₂)`` respectively. +- 1 Extra Covariate variable (``C``). + +## The Structural Causal Model + +The modeling stage starts from the definition of a Structural Causal Model (`SCM`). This is simply a list of Structural Equations (`SE`) describing the relationships between the random variables associated with our problem. See [Structural Causal Models](@ref) for an in-depth explanation. For our purposes, we will simply define it as follows: + +```@example user-guide +scm = SCM( + SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], with_encoder(LinearRegressor())), + SE(:T₁, [:W₁₁, :W₁₂], LogisticClassifier()), + SE(:T₂, [:W₂₁, :W₂₂], LogisticClassifier()), +) +``` + +--- +**NOTE** + +- Each Structural Equation specifies a child node, its parents and the assumed relationship between them. Here we know the model class from which each variable has been generated but in practice this is usually not the case. Instead we recommend the use of Super Learning / Stacking (see TODO). +- Because the treatment variables are categorical, they need to be encoded to be digested by the downstream models, `with_encoder(model)` simply creates a [Pipeline](https://alan-turing-institute.github.io/MLJ.jl/dev/linear_pipelines/#Linear-Pipelines) equivalent to `TreatmentEncoder() |> model`. +- At this point, the `SCM` has no knowledge about the dataset and cannot verify that the models are compatible with the actual data. + +--- + +## The Estimands + +From the previous causal model we can ask multiple causal questions which are each represented by a distinct estimand. The set of available estimands types can be listed as follow: + +```@example user-guide +AVAILABLE_ESTIMANDS +``` + +At the moment there are 3 main estimand types we can estimate in TMLE.jl, we provide below a few examples. + +- The Interventional Conditional Mean (see: TODO): + +```@example user-guide +cm = CM( + scm, + outcome=:Y, + treatment=(T₁=1,) + ) +``` + +- The Average Treatment Effect (see: TODO): + +```@example user-guide +ate = ATE( + scm, + outcome=:Y, + treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)) +) +marginal_ate_t1 = ATE( + scm, + outcome=:Y, + treatment=(T₁=(case=1, control=0),) +) +``` + +- The Interaction Average Treatment Effect (see: TODO): + +```@example user-guide +iate = IATE( + scm, + outcome=:Y, + treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)) +) +``` + +## Targeted Estimation + +Then each parameter can be estimated by calling the `tmle` function. For example: + +```@example user-guide +result, _ = tmle(cm, dataset) +result +``` + +The `result` contains 3 main elements: + +- The `TMLEEstimate` than can be accessed via: `tmle(result)`. +- The `OSEstimate` than can be accessed via: `ose(result)`. +- The naive initial estimate. + +The adjustment set is determined by the provided `adjustment_method` keyword. At the moment, only `BackdoorAdjustment` is available. However one can specify that extra covariates could be used to fit the outcome model. + +```@example user-guide +result, _ = tmle(iate, dataset;adjustment_method=BackdoorAdjustment([:C])) +result +``` + +## Hypothesis Testing + +Because the TMLE and OSE are asymptotically linear estimators, they asymptotically follow a Normal distribution. This means one can perform standard T tests of null hypothesis. TMLE.jl extends the method provided by the [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) package that can be used as follows. + +```@example user-guide +OneSampleTTest(tmle(result)) +``` diff --git a/src/TMLE.jl b/src/TMLE.jl index b6a1e1c6..15c52709 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -24,14 +24,16 @@ import AbstractDifferentiation as AD # ############################################################################# export SE, StructuralEquation -export StructuralCausalModel, SCM, StaticConfoundedModel +export StructuralCausalModel, SCM, StaticConfoundedModel +export setmodel!, equations, reset!, parents export CM, ATE, IATE, AVAILABLE_ESTIMANDS -export parents, fit!, reset!, equations, optimize_ordering, optimize_ordering! +export fit!, optimize_ordering, optimize_ordering! export tmle export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose -export TreatmentTransformer +export TreatmentTransformer, with_encoder export BackdoorAdjustment +export DAG, graphplot # ############################################################################# # INCLUDES @@ -45,7 +47,6 @@ include("utils.jl") include("estimation.jl") include("estimate.jl") - # ############################################################################# # PRECOMPILATION WORKLOAD # ############################################################################# @@ -66,7 +67,7 @@ function run_precompile_workload() Y₁ = μY .+ rand(n) Y₂ = categorical(rand(n) .< logistic.(μY)) - dataset = (C₁=C₁, W₁=W₁, W₂=W₂, T₁=T₁, T₂=T₂, Y₁=Y₁, Y₂) + dataset = (C₁=C₁, W₁=W₁, W₂=W₂, T₁=T₁, T₂=T₂, Y₁=Y₁, Y₂=Y₂) @compile_workload begin # all calls in this block will be precompiled, regardless of whether diff --git a/src/estimate.jl b/src/estimate.jl index be0ce081..769fa1ac 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -34,7 +34,8 @@ function Base.show(io::IO, r::TMLEResult) pretty_table(io, data;header=["Estimator", "Estimate", "95% Confidence Interval", "P-value"]) end - +tmle(result::TMLEResult) = result.tmle +ose(result::TMLEResult) = result.onestep """ estimate(r::AsymptoticallyLinearEstimate) diff --git a/src/scm.jl b/src/scm.jl index e5ff5a4a..f1f03495 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -97,8 +97,8 @@ end Base.show(io::IO, ::MIME"text/plain", eq::SE) = println(io, string_repr(eq)) -assign_model!(eq::SE, model::Nothing) = nothing -assign_model!(eq::SE, model::Model) = eq.model = model +setmodel!(eq::SE, model::Nothing) = nothing +setmodel!(eq::SE, model::Model) = eq.model = model outcome(se::SE) = se.outcome parents(se::SE) = se.parents diff --git a/src/treatment_transformer.jl b/src/treatment_transformer.jl index f31f3bc0..37cf3e60 100644 --- a/src/treatment_transformer.jl +++ b/src/treatment_transformer.jl @@ -38,3 +38,4 @@ function MLJBase.transform(model::TreatmentTransformer, fitresult, Xnew) return merge(Xt, ordered_factors) end +with_encoder(model; encoder=TreatmentTransformer()) = TreatmentTransformer() |> model \ No newline at end of file From 24ca32eb0cd35f98c93068dc43f7d59801084ed3 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 9 Aug 2023 15:23:33 +0100 Subject: [PATCH 030/151] update docs --- docs/make.jl | 2 +- docs/src/user_guide/adjustment.md | 3 +++ docs/src/user_guide/estimands.md | 14 ++++++++++ docs/src/user_guide/estimation.md | 43 +++++++++++++++++++------------ docs/src/user_guide/misc.md | 3 +++ docs/src/user_guide/scm.md | 4 +++ docs/src/walk_through.md | 4 +++ 7 files changed, 56 insertions(+), 17 deletions(-) create mode 100644 docs/src/user_guide/misc.md diff --git a/docs/make.jl b/docs/make.jl index 2f0edc14..a7405ebd 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -27,7 +27,7 @@ makedocs(; "Home" => "index.md", "Walk Through" => "walk_through.md", "User Guide" => [joinpath("user_guide", f) for f in - ("mathematical_setting.md", "scm.md", "estimands.md", "estimation.md", "adjustment.md")], + ("scm.md", "estimands.md", "estimation.md", "adjustment.md", "misc.md")], "Examples" => [ #joinpath("examples", "introduction_to_targeted_learning.md"), # joinpath("examples", "super_learning.md"), diff --git a/docs/src/user_guide/adjustment.md b/docs/src/user_guide/adjustment.md index 5bf7af90..e5078884 100644 --- a/docs/src/user_guide/adjustment.md +++ b/docs/src/user_guide/adjustment.md @@ -1,3 +1,6 @@ +```@meta +CurrentModule = TMLE +``` # Adjustment Methods In a `SCM`, each variable is determined by a set of parents and a statistical model describing the functional relationship between them. However, for the estimation of Causal Estimands the fitted models may not exactly correspond to the variable's equation in the `SCM`. Adjustment methods tell the estimation procedure which input variables should be incorporated in the statistical model fits. diff --git a/docs/src/user_guide/estimands.md b/docs/src/user_guide/estimands.md index a4d319da..b01cae47 100644 --- a/docs/src/user_guide/estimands.md +++ b/docs/src/user_guide/estimands.md @@ -1,3 +1,7 @@ +```@meta +CurrentModule = TMLE +``` + # Estimands Most causal questions can be translated into a causal estimand. Usually either an interventional or counterfactual quantity. What would have been the outcome if I had set this variable to this value? When identified, this causal estimand translates to a statistical estimand which can be estimated from data. For us, an estimand will be a functional, that is a function that takes as input a probability distribution, and outputs a real number or vector of real numbers. @@ -104,3 +108,13 @@ For a Structural Causal Model `scm`, an outcome `Y` and two treatments differenc ```julia Ψ = IATE(scm, outcome=:Y, treatment=(T₁=(case=t₁₂, control=t₁₁), T₂=(case=t₂₂, control=t₂₁))) ``` + +## Any function of the previous Estimands + +As a result of Julia's automatic differentiation facilities, given a set of already estimated estimands ``(\Psi_1, ..., \Psi_k)``, we can automatically compute an estimator for $f(\Psi_1, ..., \Psi_k)$. This is done via the `compose` function: + +```julia +compose(f, args...) +``` + +where args are asymptotically linear estimates (see [Composing Estimands](@ref)). diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index e2e4d4cc..33e3afdd 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -1,3 +1,7 @@ +```@meta +CurrentModule = TMLE +``` + # Estimation ## Estimating a single Estimand @@ -36,7 +40,7 @@ function make_dataset(;n=1000) Y = Y ) end -dataset = make_dataset() +dataset = make_dataset(n=10000) scm = SCM( SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], with_encoder(LinearRegressor())), SE(:T₁, [:W₁₁, :W₁₂], LogisticClassifier()), @@ -48,7 +52,7 @@ Once a `SCM` and an estimand have been defined, we can proceed with Targeted Est ```@example estimation Ψ₁ = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false),)) -result, fluctuation_mach = tmle(Ψ₁, dataset; +result₁, fluctuation_mach = tmle(Ψ₁, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, @@ -81,7 +85,7 @@ We could now get an interest in the Average Treatment Effect of `T₂`: ```@example estimation Ψ₂ = ATE(scm, outcome=:Y, treatment=(T₂=(case=true, control=false),)) -result, fluctuation_mach = tmle(Ψ₂, dataset; +result₂, fluctuation_mach = tmle(Ψ₂, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, @@ -98,7 +102,7 @@ Let's now see how the models can be reused with a new estimand, say the Total Av ```@example estimation Ψ₃ = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false))) -result, fluctuation_mach = tmle(Ψ₃, dataset; +result₃, fluctuation_mach = tmle(Ψ₃, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, @@ -111,7 +115,7 @@ This time only the statistical model for `Y` is fitted again while reusing the m ```@example estimation Ψ₄ = IATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false))) -result, fluctuation_mach = tmle(Ψ₄, dataset; +result₄, fluctuation_mach = tmle(Ψ₄, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, @@ -134,21 +138,28 @@ optimize_ordering([Ψ₃, Ψ₁, Ψ₂, Ψ₄]) == [Ψ₁, Ψ₃, Ψ₄, Ψ₂] By leveraging the multivariate Central Limit Theorem and Julia's automatic differentiation facilities, we can estimate any estimand which is a function of already estimated estimands. By default, TMLE.jl will use [Zygote](https://fluxml.ai/Zygote.jl/latest/) but since we are using [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) you can change the backend to your favorite AD system. -For instance, by definition of the ATE, we should be able to retrieve ``ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1}`` by composing ``CM_{T_1=1, T_2=1} - CM_{T_1=0, T_2=0}``. We already have almost all of the pieces, we just need an estimate for ``CM_{T_1=0, T_2=0}``, let's get it. +For instance, by definition of the ``IATE``, we should be able to retrieve: -```@example estimation -Ψ = CM( - outcome = :Y, - treatment = (T₁=false, T₂=false), - confounders = [:W₁, :W₂] -) -cm_result₀₀, _ = tmle(Ψ, η_spec, dataset, verbosity=0) -nothing # hide +```math +IATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} = ATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} - ATE_{T_1=0, T_2=0 \rightarrow 1} - ATE_{T_1=0 \rightarrow 1, T_2=0} ``` ```@example estimation -composed_ate_result = compose(-, cm_result₁₁.tmle, cm_result₀₀.tmle) -nothing # hide +first_ate = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=false, control=false))) +first_ate_result, _ = tmle(first_ate, dataset) + +second_ate = ATE(scm, outcome=:Y, treatment=(T₁=(case=false, control=false), T₂=(case=true, control=false))) +second_ate_result, _ = tmle(second_ate, dataset) + +composed_iate_result = compose( + (x, y, z) -> x - y - z, + tmle(result₃), tmle(first_ate_result), tmle(second_ate_result) +) +isapprox( + TMLE.estimate(tmle(result₄)), + TMLE.estimate(composed_iate_result), + atol=0.1 +) ``` ## Weighted Fluctuation diff --git a/docs/src/user_guide/misc.md b/docs/src/user_guide/misc.md new file mode 100644 index 00000000..4ac9857c --- /dev/null +++ b/docs/src/user_guide/misc.md @@ -0,0 +1,3 @@ +# Miscellaneous + +## Treatment Transformer diff --git a/docs/src/user_guide/scm.md b/docs/src/user_guide/scm.md index 66d95057..a0905922 100644 --- a/docs/src/user_guide/scm.md +++ b/docs/src/user_guide/scm.md @@ -1,3 +1,7 @@ +```@meta +CurrentModule = TMLE +``` + # Structural Causal Models In TMLE.jl, everything starts from the definition of a Structural Causal Model (`SCM`). A `SCM` in a series of Structural Equations (`SE`) that describe the causal relationships between the random variables under study. The purpose of this package is not to infer the `SCM`, instead we assume it is provided by the user. There are multiple ways one can define a `SCM` that we now describe. diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index 8686cb40..74160acc 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -1,3 +1,7 @@ +```@meta +CurrentModule = TMLE +``` + # Walk Through The goal of this section is to provide a comprehensive (but non-exhaustive) illustration of the estimation process provided in TMLE.jl. For an in-depth explanation, please refer to the User Guide. From 3342d4914a64d417139361f5c576f8b6c07759b9 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 9 Aug 2023 16:32:34 +0100 Subject: [PATCH 031/151] add some constructors for estimands --- docs/src/user_guide/misc.md | 17 +++++++++++++++ src/TMLE.jl | 8 +++++-- src/estimands.jl | 42 ++++++++++++++++++++++-------------- src/treatment_transformer.jl | 15 +++++++------ 4 files changed, 58 insertions(+), 24 deletions(-) diff --git a/docs/src/user_guide/misc.md b/docs/src/user_guide/misc.md index 4ac9857c..eeb7a641 100644 --- a/docs/src/user_guide/misc.md +++ b/docs/src/user_guide/misc.md @@ -1,3 +1,20 @@ # Miscellaneous ## Treatment Transformer + +To account for the fact that treatment variables are categorical variables we provide a MLJ compliant transformer that will either: + +- Retrieve the floating point representation of a treatment it it has a natural ordering +- One hot encode it otherwise + +Such transformer can be created with: + +```@example +TreatmentTransformer(;encoder=encoder()) +``` + +where `encoder` is a [OneHotEncoder](https://alan-turing-institute.github.io/MLJ.jl/dev/models/OneHotEncoder_MLJModels/#OneHotEncoder_MLJModels). + +The `with_encoder(model; encoder=TreatmentTransformer())` provides a shorthand to combine a `TreatmentTransformer` with another MLJ model in a pipeline. + +Of course you are also free to define your own strategy! \ No newline at end of file diff --git a/src/TMLE.jl b/src/TMLE.jl index 15c52709..81d56977 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -26,7 +26,10 @@ import AbstractDifferentiation as AD export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel export setmodel!, equations, reset!, parents -export CM, ATE, IATE, AVAILABLE_ESTIMANDS +export ConditionalMean, CM +export AverageTreatmentEffect, ATE +export InteractionAverageTreatmentEffect, IATE +export AVAILABLE_ESTIMANDS export fit!, optimize_ordering, optimize_ordering! export tmle export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint @@ -40,12 +43,13 @@ export DAG, graphplot # ############################################################################# include("scm.jl") -include("treatment_transformer.jl") include("estimands.jl") include("adjustment.jl") include("utils.jl") include("estimation.jl") include("estimate.jl") +include("treatment_transformer.jl") + # ############################################################################# # PRECOMPILATION WORKLOAD diff --git a/src/estimands.jl b/src/estimands.jl index 6beaa668..6ff0673e 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -44,11 +44,6 @@ end const CM = ConditionalMean -CM(;scm::SCM, outcome, treatment) = CM(scm, outcome, treatment) -CM(scm::SCM; outcome, treatment) = CM(scm, outcome, treatment) - -name(::Type{CM}) = "CM" - ##################################################################### ### Average Treatment Effect ### ##################################################################### @@ -86,11 +81,6 @@ end const ATE = AverageTreatmentEffect -ATE(;scm::SCM, outcome, treatment) = ATE(scm, outcome, treatment) -ATE(scm::SCM; outcome, treatment) = ATE(scm, outcome, treatment) - -name(::Type{ATE}) = "ATE" - ##################################################################### ### Interaction Average Treatment Effect ### ##################################################################### @@ -131,18 +121,38 @@ end const IATE = InteractionAverageTreatmentEffect -IATE(;scm, outcome, treatment) = IATE(scm, outcome, treatment) -IATE(scm; outcome, treatment) = IATE(scm, outcome, treatment) - -name(::Type{IATE}) = "IATE" - ##################################################################### ### Methods ### ##################################################################### AVAILABLE_ESTIMANDS = (CM, ATE, IATE) +CMCompositeTypenames = [:CM, :ATE, :IATE] +CMCompositeEstimand = Union{(eval(x) for x in CMCompositeTypenames)...} + +# Define constructors/name for CMCompositeEstimand types +for typename in CMCompositeTypenames + ex = quote + name(::Type{$(typename)}) = string($(typename)) + + $(typename)(;scm::SCM, outcome::Symbol, treatment::NamedTuple) = $(typename)(scm, outcome, treatment) + $(typename)(scm::SCM; outcome::Symbol, treatment::NamedTuple) = $(typename)(scm, outcome, treatment) + + function $(typename)( + outcome::Symbol, treatment::NamedTuple, confounders::Union{Symbol, AbstractVector{Symbol}}; + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LinearBinaryClassifier()) + scm = StaticConfoundedModel(outcome, collect(keys(treatment)), confounders; + covariates=covariates, + outcome_model=outcome_model, + treatment_model=treatment_model + ) + return $(typename)(scm; outcome=outcome, treatment=treatment) + end + end -CMCompositeEstimand = Union{CM, ATE, IATE} + eval(ex) +end VariableNotAChildInSCMError(variable) = ArgumentError(string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")) TreatmentMustBeInOutcomeParentsError(variable) = ArgumentError(string("Treatment variable ", variable, " must be a parent of the outcome.")) diff --git a/src/treatment_transformer.jl b/src/treatment_transformer.jl index 37cf3e60..ae3c9d3e 100644 --- a/src/treatment_transformer.jl +++ b/src/treatment_transformer.jl @@ -1,15 +1,18 @@ -""" -Treatments in TMLE are represented by `CategoricalArrays`. If a treatment column -has type `OrderedFactor`, then its integer representation is used, make sure that -the levels correspond to your expectations. All other columns are one-hot encoded. -""" + mutable struct TreatmentTransformer <: MLJBase.Unsupervised encoder::OneHotEncoder end encoder() = OneHotEncoder(drop_last=true, ordered_factor=false) +""" + TreatmentTransformer(;encoder=encoder()) + +Treatments in TMLE are represented by `CategoricalArrays`. If a treatment column +has type `OrderedFactor`, then its integer representation is used, make sure that +the levels correspond to your expectations. All other columns are one-hot encoded. +""" TreatmentTransformer(;encoder=encoder()) = TreatmentTransformer(encoder) MLJBase.fit(model::TreatmentTransformer, verbosity::Int, X) = @@ -38,4 +41,4 @@ function MLJBase.transform(model::TreatmentTransformer, fitresult, Xnew) return merge(Xt, ordered_factors) end -with_encoder(model; encoder=TreatmentTransformer()) = TreatmentTransformer() |> model \ No newline at end of file +with_encoder(model; encoder=TreatmentTransformer()) = TreatmentTransformer(;encoder=encoder) |> model \ No newline at end of file From 7428385e40d10549bfa9a1d9be43c5be841a60fc Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 9 Aug 2023 16:54:06 +0100 Subject: [PATCH 032/151] fix some names --- docs/src/index.md | 2 +- docs/src/user_guide/estimation.md | 16 ++++++++-------- docs/src/walk_through.md | 4 ++-- examples/double_robustness.jl | 2 +- examples/super_learning.jl | 4 ++-- src/TMLE.jl | 2 +- src/estimands.jl | 2 +- src/estimate.jl | 12 ++++++------ src/estimation.jl | 4 ++-- src/treatment_transformer.jl | 2 +- test/3points_interactions.jl | 4 ++-- test/composition.jl | 14 +++++++------- test/double_robustness_ate.jl | 20 ++++++++++---------- test/double_robustness_iate.jl | 12 ++++++------ test/estimands.jl | 12 ++++++++++++ test/estimation.jl | 24 ++++++++++++------------ test/missing_management.jl | 2 +- test/non_regression_test.jl | 4 ++-- 18 files changed, 77 insertions(+), 65 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 35a456d2..4af17b7d 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -81,7 +81,7 @@ nothing # hide Running the `tmle` will produce two asymptotically linear estimators: the TMLE and the One Step Estimator. For each we can look at the associated estimate, confidence interval and p-value: ```@example quick-start -result, _ = tmle(Ψ, dataset) +result, _ = tmle!(Ψ, dataset) result ``` diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index 33e3afdd..75dd7d4f 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -52,7 +52,7 @@ Once a `SCM` and an estimand have been defined, we can proceed with Targeted Est ```@example estimation Ψ₁ = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false),)) -result₁, fluctuation_mach = tmle(Ψ₁, dataset; +result₁, fluctuation_mach = tmle!(Ψ₁, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, @@ -85,7 +85,7 @@ We could now get an interest in the Average Treatment Effect of `T₂`: ```@example estimation Ψ₂ = ATE(scm, outcome=:Y, treatment=(T₂=(case=true, control=false),)) -result₂, fluctuation_mach = tmle(Ψ₂, dataset; +result₂, fluctuation_mach = tmle!(Ψ₂, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, @@ -102,7 +102,7 @@ Let's now see how the models can be reused with a new estimand, say the Total Av ```@example estimation Ψ₃ = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false))) -result₃, fluctuation_mach = tmle(Ψ₃, dataset; +result₃, fluctuation_mach = tmle!(Ψ₃, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, @@ -115,7 +115,7 @@ This time only the statistical model for `Y` is fitted again while reusing the m ```@example estimation Ψ₄ = IATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false))) -result₄, fluctuation_mach = tmle(Ψ₄, dataset; +result₄, fluctuation_mach = tmle!(Ψ₄, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, @@ -146,18 +146,18 @@ IATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} = ATE_{T_1=0 \rightarrow 1, T_2= ```@example estimation first_ate = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=false, control=false))) -first_ate_result, _ = tmle(first_ate, dataset) +first_ate_result, _ = tmle!(first_ate, dataset) second_ate = ATE(scm, outcome=:Y, treatment=(T₁=(case=false, control=false), T₂=(case=true, control=false))) -second_ate_result, _ = tmle(second_ate, dataset) +second_ate_result, _ = tmle!(second_ate, dataset) composed_iate_result = compose( (x, y, z) -> x - y - z, tmle(result₃), tmle(first_ate_result), tmle(second_ate_result) ) isapprox( - TMLE.estimate(tmle(result₄)), - TMLE.estimate(composed_iate_result), + estimate(tmle(result₄)), + estimate(composed_iate_result), atol=0.1 ) ``` diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index 74160acc..b742c896 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -129,7 +129,7 @@ iate = IATE( Then each parameter can be estimated by calling the `tmle` function. For example: ```@example user-guide -result, _ = tmle(cm, dataset) +result, _ = tmle!(cm, dataset) result ``` @@ -142,7 +142,7 @@ The `result` contains 3 main elements: The adjustment set is determined by the provided `adjustment_method` keyword. At the moment, only `BackdoorAdjustment` is available. However one can specify that extra covariates could be used to fit the outcome model. ```@example user-guide -result, _ = tmle(iate, dataset;adjustment_method=BackdoorAdjustment([:C])) +result, _ = tmle!(iate, dataset;adjustment_method=BackdoorAdjustment([:C])) result ``` diff --git a/examples/double_robustness.jl b/examples/double_robustness.jl index d7ef7541..d9de26c6 100644 --- a/examples/double_robustness.jl +++ b/examples/double_robustness.jl @@ -159,7 +159,7 @@ function tmle_inference(data) confounders=[:W] ) η_spec = NuisanceSpec(LinearRegressor(), LinearBinaryClassifier()) - result, _ = tmle(Ψ, η_spec, data; verbosity=0) + result, _ = tmle!(Ψ, η_spec, data; verbosity=0) tmleresult = result.tmle lb, ub = confint(OneSampleTTest(tmleresult)) return (TMLE.estimate(tmleresult), lb, ub) diff --git a/examples/super_learning.jl b/examples/super_learning.jl index d6e7a1b3..02b9ac71 100644 --- a/examples/super_learning.jl +++ b/examples/super_learning.jl @@ -194,7 +194,7 @@ nothing # hide #= Finally run the TMLE procedure and check the result =# -tmle_result, cache = tmle(Ψ, η_spec, dataset) +tmle_result, cache = tmle!Ψ, η_spec, dataset) test_result = OneSampleTTest(tmle_result.tmle, ψ₀) @@ -207,7 +207,7 @@ Now, what if we had used linear models only instead of the Super Learner? This i LogisticClassifier(lambda=0) ) -tmle_result_linear, cache = tmle(Ψ, η_spec_linear, dataset) +tmle_result_linear, cache = tmle!Ψ, η_spec_linear, dataset) test_result_linear = OneSampleTTest(tmle_result_linear.tmle, ψ₀) diff --git a/src/TMLE.jl b/src/TMLE.jl index 81d56977..e9cf2fbe 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -31,7 +31,7 @@ export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS export fit!, optimize_ordering, optimize_ordering! -export tmle +export tmle!, tmle, ose export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer, with_encoder diff --git a/src/estimands.jl b/src/estimands.jl index 6ff0673e..b8678a0e 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -142,7 +142,7 @@ for typename in CMCompositeTypenames covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, outcome_model = with_encoder(LinearRegressor()), treatment_model = LinearBinaryClassifier()) - scm = StaticConfoundedModel(outcome, collect(keys(treatment)), confounders; + scm = StaticConfoundedModel([outcome], collect(keys(treatment)), confounders; covariates=covariates, outcome_model=outcome_model, treatment_model=treatment_model diff --git a/src/estimate.jl b/src/estimate.jl index 769fa1ac..ee818fc7 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -27,8 +27,8 @@ function Base.show(io::IO, r::TMLEResult) tmletest = OneSampleTTest(r.tmle) onesteptest = OneSampleTTest(r.onestep) data = [ - :TMLE estimate(r.tmle) confint(tmletest) pvalue(tmletest); - :OSE estimate(r.onestep) confint(onesteptest) pvalue(onesteptest); + :TMLE estimate(tmle(r)) confint(tmletest) pvalue(tmletest); + :OSE estimate(ose(r)) confint(onesteptest) pvalue(onesteptest); :Naive r.initial nothing nothing ] pretty_table(io, data;header=["Estimator", "Estimate", "95% Confidence Interval", "P-value"]) @@ -38,18 +38,18 @@ tmle(result::TMLEResult) = result.tmle ose(result::TMLEResult) = result.onestep """ - estimate(r::AsymptoticallyLinearEstimate) + Distributions.estimate(r::AsymptoticallyLinearEstimate) Retrieves the final estimate: after the TMLE step. """ -estimate(r::AsymptoticallyLinearEstimate) = r.Ψ̂ +Distributions.estimate(r::AsymptoticallyLinearEstimate) = r.Ψ̂ """ - estimate(r::ComposedTMLEstimate) + Distributions.estimate(r::ComposedTMLEstimate) Retrieves the final estimate: after the TMLE step. """ -estimate(r::ComposedTMLEstimate) = length(r.Ψ̂) == 1 ? r.Ψ̂[1] : r.Ψ̂ +Distributions.estimate(r::ComposedTMLEstimate) = length(r.Ψ̂) == 1 ? r.Ψ̂[1] : r.Ψ̂ """ var(r::AsymptoticallyLinearEstimate) diff --git a/src/estimation.jl b/src/estimation.jl index b7826ca9..a7b37488 100644 --- a/src/estimation.jl +++ b/src/estimation.jl @@ -10,7 +10,7 @@ function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=Backdo end """ - tmle(Ψ::CMCompositeEstimand, dataset; + tmle!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false, @@ -30,7 +30,7 @@ Performs Targeted Minimum Loss Based Estimation of the target estimand. - threshold: The balancing score will be bounded to respect this threshold. - weighted_fluctuation: To use a weighted fluctuation instead of the vanilla TMLE, can improve stability. """ -function tmle(Ψ::CMCompositeEstimand, dataset; +function tmle!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false, diff --git a/src/treatment_transformer.jl b/src/treatment_transformer.jl index ae3c9d3e..8980647a 100644 --- a/src/treatment_transformer.jl +++ b/src/treatment_transformer.jl @@ -41,4 +41,4 @@ function MLJBase.transform(model::TreatmentTransformer, fitresult, Xnew) return merge(Xt, ordered_factors) end -with_encoder(model; encoder=TreatmentTransformer()) = TreatmentTransformer(;encoder=encoder) |> model \ No newline at end of file +with_encoder(model; encoder=encoder()) = TreatmentTransformer(;encoder=encoder) |> model \ No newline at end of file diff --git a/test/3points_interactions.jl b/test/3points_interactions.jl index b9b850f7..7eb6da1e 100644 --- a/test/3points_interactions.jl +++ b/test/3points_interactions.jl @@ -40,12 +40,12 @@ end treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false), T₃=(case=1, control=0)) ) - result, fluctuation = tmle(Ψ, dataset, verbosity=0) + result, fluctuation = tmle!(Ψ, dataset, verbosity=0) test_coverage(result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation) test_mean_inf_curve_almost_zero(result; atol=1e-10) - result, fluctuation = tmle(Ψ, dataset, verbosity=0, weighted_fluctuation=true) + result, fluctuation = tmle!(Ψ, dataset, verbosity=0, weighted_fluctuation=true) test_coverage(result, Ψ₀) test_mean_inf_curve_almost_zero(result; atol=1e-10) diff --git a/test/composition.jl b/test/composition.jl index 4f2cf571..04bd1e53 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -40,14 +40,14 @@ end outcome = :Y, treatment = (T=1,) ) - CM_result₁, _ = tmle(CM₁, dataset, verbosity=0) + CM_result₁, _ = tmle!(CM₁, dataset, verbosity=0) # Conditional Mean T = 0 CM₀ = CM( scm = scm, outcome = :Y, treatment = (T=0,) ) - CM_result₀, _ = tmle(CM₀, dataset, verbosity=0) + CM_result₀, _ = tmle!(CM₀, dataset, verbosity=0) # Via Composition CM_result_composed = compose(-, CM_result₁.tmle, CM_result₀.tmle) @@ -57,8 +57,8 @@ end outcome = :Y, treatment = (T=(case=1, control=0),), ) - ATE_result₁₀, _ = tmle(ATE₁₀, dataset, verbosity=0) - @test TMLE.estimate(ATE_result₁₀.tmle) ≈ TMLE.estimate(CM_result_composed) atol = 1e-7 + ATE_result₁₀, _ = tmle!(ATE₁₀, dataset, verbosity=0) + @test estimate(ATE_result₁₀.tmle) ≈ estimate(CM_result_composed) atol = 1e-7 @test var(ATE_result₁₀.tmle) ≈ var(CM_result_composed) atol = 1e-7 end @@ -69,18 +69,18 @@ end outcome = :Y, treatment = (T=1,) ) - CM_result₁, _ = tmle(CM₁, dataset, verbosity=0) + CM_result₁, _ = tmle!(CM₁, dataset, verbosity=0) CM₀ = CM( scm = scm, outcome = :Y, treatment = (T=0,) ) - CM_result₀, _ = tmle(CM₀, dataset, verbosity=0) + CM_result₀, _ = tmle!(CM₀, dataset, verbosity=0) f(x, y) = [x^2 - y, x/y, 2x + 3y] CM_result_composed = compose(f, CM_result₁.tmle, CM_result₀.tmle) - @test TMLE.estimate(CM_result_composed) == f(TMLE.estimate(CM_result₁.tmle), TMLE.estimate(CM_result₀.tmle)) + @test estimate(CM_result_composed) == f(estimate(CM_result₁.tmle), TMLE.estimate(CM_result₀.tmle)) @test size(var(CM_result_composed)) == (3, 3) end diff --git a/test/double_robustness_ate.jl b/test/double_robustness_ate.jl index 5b4865a5..1c7a5cf9 100644 --- a/test/double_robustness_ate.jl +++ b/test/double_robustness_ate.jl @@ -124,7 +124,7 @@ end scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() scm.T.model = LogisticClassifier(lambda=0) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -136,7 +136,7 @@ end scm.Y.model = TreatmentTransformer() |> LinearRegressor() scm.T.model = ConstantClassifier() - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -153,7 +153,7 @@ end scm.Y.model = TreatmentTransformer() |> ConstantClassifier() scm.T.model = LogisticClassifier(lambda=0) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) @@ -165,7 +165,7 @@ end scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) scm.T.model = ConstantClassifier() - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) @@ -183,7 +183,7 @@ end scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() scm.T.model = LogisticClassifier(lambda=0) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -195,7 +195,7 @@ end scm.Y.model = TreatmentTransformer() |> LinearRegressor() scm.T.model = ConstantClassifier() - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -214,7 +214,7 @@ end scm.T₁.model = LogisticClassifier(lambda=0) scm.T₂.model = LogisticClassifier(lambda=0) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₁) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -224,7 +224,7 @@ end scm.T₁.model = ConstantClassifier() scm.T₂.model = ConstantClassifier() - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₁) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -235,7 +235,7 @@ end outcome = :Y, treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=0.)), ) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -246,7 +246,7 @@ end scm.T₁.model = LogisticClassifier(lambda=0) scm.T₂.model = LogisticClassifier(lambda=0) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) diff --git a/test/double_robustness_iate.jl b/test/double_robustness_iate.jl index fb89060d..be1757b9 100644 --- a/test/double_robustness_iate.jl +++ b/test/double_robustness_iate.jl @@ -187,7 +187,7 @@ end scm.T₁.model = LogisticClassifier(lambda=0) scm.T₂.model = LogisticClassifier(lambda=0) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0); + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) @@ -200,7 +200,7 @@ end scm.T₁.model = ConstantClassifier() scm.T₂.model = ConstantClassifier() - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0); + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) @@ -221,7 +221,7 @@ end scm.T₁.model = LogisticClassifier(lambda=0) scm.T₂.model = LogisticClassifier(lambda=0) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -234,7 +234,7 @@ end scm.T₁.model = ConstantClassifier() scm.T₂.model = ConstantClassifier() - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -253,7 +253,7 @@ end scm.T₁.model = LogisticClassifier(lambda=0) scm.T₂.model = LogisticClassifier(lambda=0) - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0); + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) @@ -265,7 +265,7 @@ end scm.T₁.model = ConstantClassifier() scm.T₂.model = ConstantClassifier() - tmle_result, fluctuation_mach = tmle(Ψ, dataset, verbosity=0); + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_fluct_mean_inf_curve_lower_than_initial(tmle_result) diff --git a/test/estimands.jl b/test/estimands.jl index 847b13d7..3dfb1f51 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -72,6 +72,18 @@ using MLJGLMInterface @test TMLE.outcome(Ψ) == :Z end +@testset "Test constructors" begin + Ψ = CM(:Y, (T=1,), [:W]) + @test parents(Ψ.scm.Y) == [:T, :W] + + Ψ = ATE(:Y, (T₁=(case=1, control=0), T₂=(case=1, control=0)), :W) + @test parents(Ψ.scm.Y) == [:T₁, :T₂, :W] + + Ψ = IATE(:Y, (T₁=(case=1, control=0), T₂=(case=1, control=0)), :W) + @test parents(Ψ.scm.Y) == [:T₁, :T₂, :W] +end + + @testset "Test fit!" begin rng = StableRNG(123) n = 100 diff --git a/test/estimation.jl b/test/estimation.jl index dab5134e..2178a744 100644 --- a/test/estimation.jl +++ b/test/estimation.jl @@ -74,7 +74,7 @@ table_types = (Tables.columntable, DataFrame) (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset; verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset; verbosity=1); # The TMLE covers the ground truth but the initial estimate does not Ψ₀ = -1 @test tmle_result.estimand == Ψ @@ -94,7 +94,7 @@ table_types = (Tables.columntable, DataFrame) (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) @@ -115,7 +115,7 @@ table_types = (Tables.columntable, DataFrame) (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 0.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) @@ -135,7 +135,7 @@ table_types = (Tables.columntable, DataFrame) (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 10 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) @@ -154,7 +154,7 @@ table_types = (Tables.columntable, DataFrame) (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = -0.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) @@ -172,7 +172,7 @@ table_types = (Tables.columntable, DataFrame) (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = -1.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) @@ -197,7 +197,7 @@ end (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 3 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) @@ -215,7 +215,7 @@ end (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 2.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) @@ -234,7 +234,7 @@ end (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 1.5 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) @@ -249,7 +249,7 @@ end (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) @@ -272,7 +272,7 @@ end (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = -1 @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) @@ -290,7 +290,7 @@ end (:info, "Performing TMLE..."), (:info, "Done.") ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle(Ψ, dataset, verbosity=1); + tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) diff --git a/test/missing_management.jl b/test/missing_management.jl index f38b8085..297c69ac 100644 --- a/test/missing_management.jl +++ b/test/missing_management.jl @@ -47,7 +47,7 @@ end scm = scm, outcome=:Y, treatment=(T=(case=1, control=0),)) - tmle_result, fluctuation_mach = tmle(Ψ, dataset; verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset; verbosity=0) test_coverage(tmle_result, 1) test_fluct_decreases_risk(Ψ, fluctuation_mach) end diff --git a/test/non_regression_test.jl b/test/non_regression_test.jl index 5f28f130..17653073 100644 --- a/test/non_regression_test.jl +++ b/test/non_regression_test.jl @@ -23,9 +23,9 @@ using MLJGLMInterface ) Ψ = ATE(scm=scm, outcome=:haz01, treatment=(parity01=(case=1, control=0),)) - result, fluctuation_mach = tmle(Ψ, dataset, threshold=0.025, verbosity=0) + result, fluctuation_mach = tmle!(Ψ, dataset, threshold=0.025, verbosity=0) l, u = confint(OneSampleTTest(result.tmle)) - @test TMLE.estimate(result.tmle) ≈ -0.185533 atol = 1e-6 + @test estimate(result.tmle) ≈ -0.185533 atol = 1e-6 @test l ≈ -0.279246 atol = 1e-6 @test u ≈ -0.091821 atol = 1e-6 end From 12be74c7e7919509e720ab29220ceb6a0ead55f8 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 9 Aug 2023 17:19:37 +0100 Subject: [PATCH 033/151] more docs --- docs/src/index.md | 59 ++++++++++++++++++------------- docs/src/user_guide/adjustment.md | 1 + docs/src/user_guide/misc.md | 1 + docs/src/user_guide/scm.md | 4 +++ docs/src/walk_through.md | 19 +++++----- 5 files changed, 51 insertions(+), 33 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 4af17b7d..6d02e824 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -20,9 +20,9 @@ Pkg> add TMLE To run an estimation procedure, we need 3 ingredients: -1. A Structural Causal Model that describes the relationship between the variables. -2. An estimand of interest -3. A dataset +1. A dataset +2. A Structural Causal Model that describes the relationship between the variables. +3. An estimand of interest For illustration, assume we know the actual data generating process is as follows: @@ -34,15 +34,45 @@ Y &\sim \mathcal{Normal}(1 + 3 \cdot T - T \cdot W, 0.01) \end{aligned} ``` -Let's define the associated Structural Causal Model: +Because we know the data generating process, we can simulate some data accordingly: ```@example quick-start using TMLE +using Distributions +using StableRNGs +using Random +using CategoricalArrays +using MLJLinearModels +using LogExpFunctions + +rng = StableRNG(123) +n = 100 +W = rand(rng, Uniform(), n) +T = rand(rng, Uniform(), n) .< logistic.(1 .- 2W) +Y = 1 .+ 3T .- T.*W .+ rand(rng, Normal(0, 0.01), n) +dataset = (Y=Y, T=categorical(T), W=W) +nothing # hide +``` + +### Two lines TMLE + +Estimating the Average Treatment Effect can of ``T`` on ``Y`` can be as simple as: + +```@example quick-start +Ψ = ATE(:Y, (T=(case=true, control = false),), :W) +result, _ = tmle!(Ψ, dataset) +result +``` +### Two steps approach + +Let's first define the Structural Causal Model: + +```@example quick-start scm = StaticConfoundedModel(:Y, :T, :W) ``` -Now, say we are interested in the Average Treatment Effect of the treatment ``T`` on the outcome ``Y``: $ATE_{T:0 \rightarrow 1}(Y)$: +and second, define the Average Treatment Effect of the treatment ``T`` on the outcome ``Y``: $ATE_{T:0 \rightarrow 1}(Y)$: ```@example quick-start Ψ = ATE( @@ -59,25 +89,6 @@ Note that in this example the ATE can be computed exactly and is given by: ATE_{0 \rightarrow 1}(P_0) = \mathbb{E}[1 + 3 - W] - \mathbb{E}[1] = 3 - \mathbb{E}[W] = 2.5 ``` -Because we know the data generating process, we can simulate some data accordingly: - -```@example quick-start -using Distributions -using StableRNGs -using Random -using CategoricalArrays -using MLJLinearModels -using LogExpFunctions - -rng = StableRNG(123) -n = 100 -W = rand(rng, Uniform(), n) -T = rand(rng, Uniform(), n) .< logistic.(1 .- 2W) -Y = 1 .+ 3T .- T.*W .+ rand(rng, Normal(0, 0.01), n) -dataset = (Y=Y, T=categorical(T), W=W) -nothing # hide -``` - Running the `tmle` will produce two asymptotically linear estimators: the TMLE and the One Step Estimator. For each we can look at the associated estimate, confidence interval and p-value: ```@example quick-start diff --git a/docs/src/user_guide/adjustment.md b/docs/src/user_guide/adjustment.md index e5078884..d0d87293 100644 --- a/docs/src/user_guide/adjustment.md +++ b/docs/src/user_guide/adjustment.md @@ -10,6 +10,7 @@ In a `SCM`, each variable is determined by a set of parents and a statistical mo At the moment we provide a single adjustment method, namely the Backdoor adjustment method. The adjustment set consists of all the treatment variable's parents. Additional covariates used to fit the outcome model can be provided via `outcome_extra`. ```@example +using TMLE # hide BackdoorAdjustment(;outcome_extra=[:C]) nothing ``` diff --git a/docs/src/user_guide/misc.md b/docs/src/user_guide/misc.md index eeb7a641..8fced69c 100644 --- a/docs/src/user_guide/misc.md +++ b/docs/src/user_guide/misc.md @@ -10,6 +10,7 @@ To account for the fact that treatment variables are categorical variables we pr Such transformer can be created with: ```@example +using TMLE # hide TreatmentTransformer(;encoder=encoder()) ``` diff --git a/docs/src/user_guide/scm.md b/docs/src/user_guide/scm.md index a0905922..53e56efa 100644 --- a/docs/src/user_guide/scm.md +++ b/docs/src/user_guide/scm.md @@ -11,6 +11,7 @@ In TMLE.jl, everything starts from the definition of a Structural Causal Model ( All models are wrong? Well maybe not the following: ```@example scm-incremental +using TMLE # hide scm = SCM() ``` @@ -48,6 +49,7 @@ push!(scm, SE(:T₂, [:W₂₁, :W₂₂, :W])) Instead of constructing the `SCM` incrementally, one can provide all the specified equations at once: ```@example scm-one-step +using TMLE # hide scm = SCM( SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C], with_encoder(LinearRegressor())), SE(:T₁, [:W₁₁, :W₁₂, :W], model=LogisticClassifier()), @@ -62,6 +64,7 @@ Noting that we have used the `with_encoder` function to reflect the fact that we There are many cases where we are interested in estimating the causal effect of a single treatment variable on a single outcome. Because it is typically only necessary to adjust for backdoor variables in order to identify this causal effect, we provide the `StaticConfoundedModel` interface to build such `SCM`: ```@example static-scm-1 +using TMLE # hide scm = StaticConfoundedModel( :Y, :T, [:W₁, :W₂]; covariates=[:C], @@ -75,6 +78,7 @@ The optional `covariates` are variables that influence the outcome but are not c This model can be extended to a plate-model with multiple treatments and multiple outcomes. In this case the set of confounders is assumed to confound all treatments which are in turn assumed to impact all outcomes. This can be defined as: ```@example static-scm-2 +using TMLE # hide scm = StaticConfoundedModel( [:Y₁, :Y₂], [:T₁, :T₂], [:W₁, :W₂]; covariates=[:C], diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index b742c896..a7219788 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -14,7 +14,8 @@ TMLE.jl is compatible with any dataset respecting the [Tables.jl](https://tables The dataset is generated as follows: -```@example user-guide +```@example walk-through +using TMLE using Random using Distributions using DataFrames @@ -62,7 +63,7 @@ Even though the role of a variable (treatment, outcome, confounder, ...) is rela The modeling stage starts from the definition of a Structural Causal Model (`SCM`). This is simply a list of Structural Equations (`SE`) describing the relationships between the random variables associated with our problem. See [Structural Causal Models](@ref) for an in-depth explanation. For our purposes, we will simply define it as follows: -```@example user-guide +```@example walk-through scm = SCM( SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], with_encoder(LinearRegressor())), SE(:T₁, [:W₁₁, :W₁₂], LogisticClassifier()), @@ -83,7 +84,7 @@ scm = SCM( From the previous causal model we can ask multiple causal questions which are each represented by a distinct estimand. The set of available estimands types can be listed as follow: -```@example user-guide +```@example walk-through AVAILABLE_ESTIMANDS ``` @@ -91,7 +92,7 @@ At the moment there are 3 main estimand types we can estimate in TMLE.jl, we pro - The Interventional Conditional Mean (see: TODO): -```@example user-guide +```@example walk-through cm = CM( scm, outcome=:Y, @@ -101,7 +102,7 @@ cm = CM( - The Average Treatment Effect (see: TODO): -```@example user-guide +```@example walk-through ate = ATE( scm, outcome=:Y, @@ -116,7 +117,7 @@ marginal_ate_t1 = ATE( - The Interaction Average Treatment Effect (see: TODO): -```@example user-guide +```@example walk-through iate = IATE( scm, outcome=:Y, @@ -128,7 +129,7 @@ iate = IATE( Then each parameter can be estimated by calling the `tmle` function. For example: -```@example user-guide +```@example walk-through result, _ = tmle!(cm, dataset) result ``` @@ -141,7 +142,7 @@ The `result` contains 3 main elements: The adjustment set is determined by the provided `adjustment_method` keyword. At the moment, only `BackdoorAdjustment` is available. However one can specify that extra covariates could be used to fit the outcome model. -```@example user-guide +```@example walk-through result, _ = tmle!(iate, dataset;adjustment_method=BackdoorAdjustment([:C])) result ``` @@ -150,6 +151,6 @@ result Because the TMLE and OSE are asymptotically linear estimators, they asymptotically follow a Normal distribution. This means one can perform standard T tests of null hypothesis. TMLE.jl extends the method provided by the [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) package that can be used as follows. -```@example user-guide +```@example walk-through OneSampleTTest(tmle(result)) ``` From f666975e645162b55e15c1513e9ea187dfb515a9 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 10 Aug 2023 18:59:04 +0100 Subject: [PATCH 034/151] fix some constructors ambiguity --- docs/make.jl | 14 +++++---- examples/double_robustness.jl | 11 ++++--- src/estimands.jl | 7 +++-- test/3points_interactions.jl | 2 +- test/composition.jl | 10 +++---- test/double_robustness_ate.jl | 10 +++---- test/double_robustness_iate.jl | 6 ++-- test/estimands.jl | 54 +++++++++++++++++----------------- test/estimation.jl | 20 ++++++------- test/missing_management.jl | 2 +- test/non_regression_test.jl | 2 +- test/utils.jl | 18 ++++++------ 12 files changed, 79 insertions(+), 77 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index a7405ebd..ae3c7c60 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -5,12 +5,14 @@ using Literate DocMeta.setdocmeta!(TMLE, :DocTestSetup, :(using TMLE); recursive=true) ## Generate Literate markdown pages -# @info "Building Literate pages..." -# examples_dir = joinpath(@__DIR__, "../examples") -# build_examples_dir = joinpath(@__DIR__, "src", "examples/") -# for file in readdir(examples_dir) -# Literate.markdown(joinpath(examples_dir, file), build_examples_dir;documenter=true) -# end +@info "Building Literate pages..." +examples_dir = joinpath(@__DIR__, "../examples") +build_examples_dir = joinpath(@__DIR__, "src", "examples/") +for file in readdir(examples_dir) + if endswith(file, "jl") + Literate.markdown(joinpath(examples_dir, file), build_examples_dir;documenter=true) + end +end @info "Running makedocs..." makedocs(; diff --git a/examples/double_robustness.jl b/examples/double_robustness.jl index d9de26c6..bda768cc 100644 --- a/examples/double_robustness.jl +++ b/examples/double_robustness.jl @@ -154,13 +154,12 @@ that we now have full coverage of the ground truth. function tmle_inference(data) Ψ = ATE( - outcome=:Y, - treatment=(Tcat=(case=1.0, control=0.0),), - confounders=[:W] + :Y, + (Tcat=(case=1.0, control=0.0),), + [:W] ) - η_spec = NuisanceSpec(LinearRegressor(), LinearBinaryClassifier()) - result, _ = tmle!(Ψ, η_spec, data; verbosity=0) - tmleresult = result.tmle + result, _ = tmle!(Ψ, data; verbosity=0) + tmleresult = tmle(result) lb, ub = confint(OneSampleTTest(tmleresult)) return (TMLE.estimate(tmleresult), lb, ub) end diff --git a/src/estimands.jl b/src/estimands.jl index b8678a0e..0f726232 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -134,11 +134,12 @@ for typename in CMCompositeTypenames ex = quote name(::Type{$(typename)}) = string($(typename)) - $(typename)(;scm::SCM, outcome::Symbol, treatment::NamedTuple) = $(typename)(scm, outcome, treatment) $(typename)(scm::SCM; outcome::Symbol, treatment::NamedTuple) = $(typename)(scm, outcome, treatment) - function $(typename)( - outcome::Symbol, treatment::NamedTuple, confounders::Union{Symbol, AbstractVector{Symbol}}; + function $(typename)(; + outcome::Symbol, + treatment::NamedTuple, + confounders::Union{Symbol, AbstractVector{Symbol}}, covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, outcome_model = with_encoder(LinearRegressor()), treatment_model = LinearBinaryClassifier()) diff --git a/test/3points_interactions.jl b/test/3points_interactions.jl index 7eb6da1e..ceb91b6e 100644 --- a/test/3points_interactions.jl +++ b/test/3points_interactions.jl @@ -35,7 +35,7 @@ end @testset "Test 3-points interactions" begin dataset, scm, Ψ₀ = dataset_scm_and_truth(;n=1000) Ψ = IATE( - scm = scm, + scm, outcome = :Y, treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false), T₃=(case=1, control=0)) ) diff --git a/test/composition.jl b/test/composition.jl index 04bd1e53..393a38d5 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -36,14 +36,14 @@ end dataset, scm = make_dataset_and_scm(;n=1000) # Conditional Mean T = 1 CM₁ = CM( - scm = scm, + scm, outcome = :Y, treatment = (T=1,) ) CM_result₁, _ = tmle!(CM₁, dataset, verbosity=0) # Conditional Mean T = 0 CM₀ = CM( - scm = scm, + scm, outcome = :Y, treatment = (T=0,) ) @@ -53,7 +53,7 @@ end # Via ATE ATE₁₀ = ATE( - scm=scm, + scm, outcome = :Y, treatment = (T=(case=1, control=0),), ) @@ -65,14 +65,14 @@ end @testset "Test compose multidimensional function" begin dataset, scm = make_dataset_and_scm(;n=1000) CM₁ = CM( - scm = scm, + scm, outcome = :Y, treatment = (T=1,) ) CM_result₁, _ = tmle!(CM₁, dataset, verbosity=0) CM₀ = CM( - scm = scm, + scm, outcome = :Y, treatment = (T=0,) ) diff --git a/test/double_robustness_ate.jl b/test/double_robustness_ate.jl index 1c7a5cf9..eca65cb0 100644 --- a/test/double_robustness_ate.jl +++ b/test/double_robustness_ate.jl @@ -116,7 +116,7 @@ end @testset "Test Double Robustness ATE on continuous_outcome_categorical_treatment_pb" begin dataset, scm, Ψ₀ = continuous_outcome_categorical_treatment_pb(;n=10_000, control="TT", case="AA") Ψ = ATE( - scm = scm, + scm, outcome = :Y, treatment = (T=(case="AA", control="TT"),), ) @@ -145,7 +145,7 @@ end @testset "Test Double Robustness ATE on binary_outcome_binary_treatment_pb" begin dataset, scm, Ψ₀ = binary_outcome_binary_treatment_pb(;n=10_000) Ψ = ATE( - scm = scm, + scm, outcome = :Y, treatment = (T=(case=true, control=false),), ) @@ -175,7 +175,7 @@ end @testset "Test Double Robustness ATE on continuous_outcome_binary_treatment_pb" begin dataset, scm, Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) Ψ = ATE( - scm=scm, + scm, outcome = :Y, treatment = (T=(case=true, control=false),), ) @@ -205,7 +205,7 @@ end dataset, scm, (ATE₁₁₋₀₁, ATE₁₁₋₀₀) = dataset_2_treatments_pb(;rng = StableRNG(123), n=50_000) # Test first ATE, only T₁ treatment varies Ψ = ATE( - scm=scm, + scm, outcome = :Y, treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=1.)), ) @@ -231,7 +231,7 @@ end # Test second ATE, two treatment varies Ψ = ATE( - scm = scm, + scm, outcome = :Y, treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=0.)), ) diff --git a/test/double_robustness_iate.jl b/test/double_robustness_iate.jl index be1757b9..f08334d9 100644 --- a/test/double_robustness_iate.jl +++ b/test/double_robustness_iate.jl @@ -178,7 +178,7 @@ end @testset "Test Double Robustness IATE on binary_outcome_binary_treatment_pb" begin dataset, scm, Ψ₀ = binary_outcome_binary_treatment_pb(n=10_000) Ψ = IATE( - scm = scm, + scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), ) @@ -212,7 +212,7 @@ end @testset "Test Double Robustness IATE on continuous_outcome_binary_treatment_pb" begin dataset, scm,Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) Ψ = IATE( - scm=scm, + scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), ) @@ -244,7 +244,7 @@ end @testset "Test Double Robustness IATE on binary_outcome_categorical_treatment_pb" begin dataset, scm, Ψ₀ = binary_outcome_categorical_treatment_pb(n=30_000) Ψ = IATE( - scm=scm, + scm, outcome=:Y, treatment=(T₁=(case="CC", control="CG"), T₂=(case="AT", control="AA")), ) diff --git a/test/estimands.jl b/test/estimands.jl index 3dfb1f51..a87983fd 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -39,7 +39,7 @@ using MLJGLMInterface ) Ψ = CM( - scm=scm, + scm, treatment=(T=1,), outcome =:Y ) @@ -55,7 +55,7 @@ using MLJGLMInterface SE(:T₂, [:W₂₁]), ) Ψ = ATE( - scm=scm, + scm, treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), outcome =:Z, ) @@ -64,7 +64,7 @@ using MLJGLMInterface # IATE Ψ = IATE( - scm=scm, + scm, treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), outcome =:Z, ) @@ -73,13 +73,13 @@ using MLJGLMInterface end @testset "Test constructors" begin - Ψ = CM(:Y, (T=1,), [:W]) + Ψ = CM(outcome=:Y, treatment=(T=1,), confounders=[:W]) @test parents(Ψ.scm.Y) == [:T, :W] - Ψ = ATE(:Y, (T₁=(case=1, control=0), T₂=(case=1, control=0)), :W) + Ψ = ATE(outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), confounders=:W) @test parents(Ψ.scm.Y) == [:T₁, :T₂, :W] - Ψ = IATE(:Y, (T₁=(case=1, control=0), T₂=(case=1, control=0)), :W) + Ψ = IATE(outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), confounders=:W) @test parents(Ψ.scm.Y) == [:T₁, :T₂, :W] end @@ -108,7 +108,7 @@ end # Fits CM required equations Ψ = CM( - scm=scm, + scm, treatment=(T₁=1,), outcome=:Ycat ) @@ -122,7 +122,7 @@ end # Fits ATE required equations that haven't been fitted yet. Ψ = ATE( - scm=scm, + scm, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), outcome =:Ycat, ) @@ -139,7 +139,7 @@ end # Fits IATE required equations that haven't been fitted yet. Ψ = IATE( - scm=scm, + scm, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), outcome =:Ycont, ) @@ -170,47 +170,47 @@ end ) estimands = [ ATE( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), ), IATE( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=(case=1, control=0), T₂=(case="AA", control="CC")), ), ATE( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=(case=1, control=0),), ), ATE( - scm=scm, + scm, outcome=:Y₂, treatment=(T₂=(case="AC", control="CC"),), ), CM( - scm=scm, + scm, outcome=:Y₂, treatment=(T₂="AC",), ), IATE( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=(case=0, control=1), T₂=(case="AC", control="CC"),), ), CM( - scm=scm, + scm, outcome=:Y₂, treatment=(T₂="CC",), ), ATE( - scm=scm, + scm, outcome=:Y₂, treatment=(T₂=(case="AA", control="CC"),), ), ATE( - scm=scm, + scm, outcome=:Y₂, treatment=(T₂=(case="AA", control="AC"),), ), @@ -223,16 +223,16 @@ end ordered_estimands = optimize_ordering(estimands) expected_ordering = [ # Y₁ - ATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),)), - ATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AC", control = "CC"))), - IATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 0, control = 1), T₂ = (case = "AC", control = "CC"),)), - IATE(scm=scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AA", control = "CC"))), + ATE(scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),)), + ATE(scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AC", control = "CC"))), + IATE(scm, outcome=:Y₁, treatment=(T₁ = (case = 0, control = 1), T₂ = (case = "AC", control = "CC"),)), + IATE(scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AA", control = "CC"))), # Y₂ - ATE(scm=scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "AC"),)), - CM(scm=scm, outcome=:Y₂, treatment=(T₂ = "CC",)), - ATE(scm=scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "CC"),)), - CM(scm=scm, outcome=:Y₂, treatment=(T₂ = "AC",)), - ATE(scm=scm, outcome=:Y₂, treatment=(T₂ = (case = "AC", control = "CC"),)), + ATE(scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "AC"),)), + CM(scm, outcome=:Y₂, treatment=(T₂ = "CC",)), + ATE(scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "CC"),)), + CM(scm, outcome=:Y₂, treatment=(T₂ = "AC",)), + ATE(scm, outcome=:Y₂, treatment=(T₂ = (case = "AC", control = "CC"),)), ] @test ordered_estimands == expected_ordering # Mutating function diff --git a/test/estimation.jl b/test/estimation.jl index 2178a744..339367e4 100644 --- a/test/estimation.jl +++ b/test/estimation.jl @@ -62,7 +62,7 @@ table_types = (Tables.columntable, DataFrame) dataset = tt(dataset) # Define the estimand of interest Ψ = ATE( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=(case=true, control=false),), ) @@ -85,7 +85,7 @@ table_types = (Tables.columntable, DataFrame) # Update the treatment specification # Nuisance estimands should not fitted again Ψ = ATE( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=(case=false, control=true),), ) @@ -104,7 +104,7 @@ table_types = (Tables.columntable, DataFrame) # New marginal treatment estimation: T₂ # This will trigger fit of the corresponding equations Ψ = ATE( - scm=scm, + scm, outcome=:Y₁, treatment=(T₂=(case=true, control=false),), ) @@ -125,7 +125,7 @@ table_types = (Tables.columntable, DataFrame) # Change the outcome: Y₂ # This will trigger the fit for the corresponding equation Ψ = ATE( - scm=scm, + scm, outcome=:Y₂, treatment=(T₂=(case=true, control=false),), ) @@ -144,7 +144,7 @@ table_types = (Tables.columntable, DataFrame) # New estimand with 2 treatment variables Ψ = ATE( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), ) @@ -163,7 +163,7 @@ table_types = (Tables.columntable, DataFrame) # Switching the T₂ setting Ψ = ATE( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=(case=true, control=false), T₂=(case=false, control=true)), ) @@ -185,7 +185,7 @@ end dataset = tt(dataset) Ψ = CM( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=true, T₂=true), ) @@ -206,7 +206,7 @@ end # Let's switch case and control for T₂ Ψ = CM( - scm=scm, + scm, outcome=:Y₁, treatment=(T₁=true, T₂=false), ) @@ -224,7 +224,7 @@ end # Change the outcome Ψ = CM( - scm=scm, + scm, outcome=:Y₂, treatment=(T₂=false, ), ) @@ -260,7 +260,7 @@ end dataset, scm = build_dataset_and_scm(;n=50_000) dataset = tt(dataset) Ψ = IATE( - scm=scm, + scm, outcome=:Y₃, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), ) diff --git a/test/missing_management.jl b/test/missing_management.jl index 297c69ac..7211973a 100644 --- a/test/missing_management.jl +++ b/test/missing_management.jl @@ -44,7 +44,7 @@ end @testset "Test estimation with missing values and ordered factor treatment" begin dataset, scm = dataset_with_missing_and_ordered_treatment(;n=1000) Ψ = ATE( - scm = scm, + scm, outcome=:Y, treatment=(T=(case=1, control=0),)) tmle_result, fluctuation_mach = tmle!(Ψ, dataset; verbosity=0) diff --git a/test/non_regression_test.jl b/test/non_regression_test.jl index 17653073..49a2fe8e 100644 --- a/test/non_regression_test.jl +++ b/test/non_regression_test.jl @@ -21,7 +21,7 @@ using MLJGLMInterface outcome_model = TreatmentTransformer() |> LinearBinaryClassifier(), treatment_model = LinearBinaryClassifier() ) - Ψ = ATE(scm=scm, outcome=:haz01, treatment=(parity01=(case=1, control=0),)) + Ψ = ATE(scm, outcome=:haz01, treatment=(parity01=(case=1, control=0),)) result, fluctuation_mach = tmle!(Ψ, dataset, threshold=0.025, verbosity=0) l, u = confint(OneSampleTTest(result.tmle)) diff --git a/test/utils.jl b/test/utils.jl index 4efc3cb2..1379b6a4 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -24,14 +24,14 @@ using MLJModels SE(:Y, [:T, :W₁]) ) Ψ = ATE( - scm=scm, + scm, outcome=:Y, treatment=(T=(case=1, control=0),) ) @test_throws TMLE.AbsentLevelError("T", "control", 0, [1, 2, 3]) TMLE.check_treatment_levels(Ψ, dataset) Ψ = CM( - scm=scm, + scm, outcome=:Y, treatment=(T=0,), ) @@ -53,7 +53,7 @@ end ) # Conditional Mean Ψ = CM( - scm=scm, + scm, outcome=:Y, treatment=(T₁="A", T₂=1), ) @@ -63,7 +63,7 @@ end @test indic_values == [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0] # ATE Ψ = ATE( - scm=scm, + scm, outcome=:Y, treatment=(T₁=(case="A", control="B"), T₂=(control=0, case=1)), ) @@ -76,7 +76,7 @@ end @test indic_values == [0.0, -1.0, 1.0, 0.0, 0.0, -1.0, 1.0, 0.0] # 2-points IATE Ψ = IATE( - scm=scm, + scm, outcome=:Y, treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0)), ) @@ -91,7 +91,7 @@ end @test indic_values == [-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0] # 3-points IATE Ψ = IATE( - scm=scm, + scm, outcome=:Y, treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0), T₃=(control="D", case="C")), ) @@ -163,7 +163,7 @@ end outcome_model = ConstantRegressor() ) Ψ = ATE( - scm=scm, + scm, outcome=:Y, treatment=(T=(case="a", control="b"),), ) @@ -214,7 +214,7 @@ end outcome_model = ConstantClassifier() ) Ψ = IATE( - scm=scm, + scm, outcome =:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), ) @@ -258,7 +258,7 @@ end outcome_model = DeterministicConstantRegressor() ) Ψ = IATE( - scm=scm, + scm, outcome =:Y, treatment=(T₁=(case="a", control="b"), T₂=(case=1, control=2), From 19e27ee8f4c418b9a96142c5a55223e9b80af81e Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 11 Aug 2023 00:23:56 +0100 Subject: [PATCH 035/151] update docs --- docs/src/index.md | 3 +-- docs/src/user_guide/adjustment.md | 4 +--- docs/src/user_guide/estimation.md | 5 +++++ docs/src/user_guide/misc.md | 3 +-- docs/src/user_guide/scm.md | 20 +++++++++----------- docs/src/walk_through.md | 29 ++++++++++++++++++++++------- src/TMLE.jl | 2 +- src/estimate.jl | 9 ++++++++- 8 files changed, 48 insertions(+), 27 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index 6d02e824..f7a675fa 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -59,7 +59,7 @@ nothing # hide Estimating the Average Treatment Effect can of ``T`` on ``Y`` can be as simple as: ```@example quick-start -Ψ = ATE(:Y, (T=(case=true, control = false),), :W) +Ψ = ATE(outcome=:Y, treatment=(T=(case=true, control = false),), confounders=:W) result, _ = tmle!(Ψ, dataset) result ``` @@ -80,7 +80,6 @@ and second, define the Average Treatment Effect of the treatment ``T`` on the ou outcome = :Y, treatment = (T=(case=true, control = false),), ) -nothing # hide ``` Note that in this example the ATE can be computed exactly and is given by: diff --git a/docs/src/user_guide/adjustment.md b/docs/src/user_guide/adjustment.md index d0d87293..712d924e 100644 --- a/docs/src/user_guide/adjustment.md +++ b/docs/src/user_guide/adjustment.md @@ -9,8 +9,6 @@ In a `SCM`, each variable is determined by a set of parents and a statistical mo At the moment we provide a single adjustment method, namely the Backdoor adjustment method. The adjustment set consists of all the treatment variable's parents. Additional covariates used to fit the outcome model can be provided via `outcome_extra`. -```@example -using TMLE # hide +```julia BackdoorAdjustment(;outcome_extra=[:C]) -nothing ``` diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index 75dd7d4f..8761ae1d 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -14,6 +14,7 @@ using StableRNGs using CategoricalArrays using TMLE using LogExpFunctions +using MLJLinearModels function make_dataset(;n=1000) rng = StableRNG(123) @@ -59,6 +60,7 @@ result₁, fluctuation_mach = tmle!(Ψ₁, dataset; threshold=1e-8, weighted_fluctuation=false ) +nothing # hide ``` We see that both models corresponding to variables `Y` and `T₁` were fitted in the process but that the model for `T₂` was not because it was not necessary to estimate this estimand. @@ -92,6 +94,7 @@ result₂, fluctuation_mach = tmle!(Ψ₂, dataset; threshold=1e-8, weighted_fluctuation=false ) +nothing # hide ``` The model for `T₂` was fitted in the process but so was the model for `Y` 🤔. This is because the `BackdoorAdjustment` method determined that the set of inputs for `Y` were different in both cases. @@ -109,6 +112,7 @@ result₃, fluctuation_mach = tmle!(Ψ₃, dataset; threshold=1e-8, weighted_fluctuation=false ) +nothing # hide ``` This time only the statistical model for `Y` is fitted again while reusing the models for `T₁` and `T₂`. Finally, let's see what happens if we estimate the `IATE` between `T₁` and `T₂`. @@ -122,6 +126,7 @@ result₄, fluctuation_mach = tmle!(Ψ₄, dataset; threshold=1e-8, weighted_fluctuation=false ) +nothing # hide ``` All statistical models have been reused 😊! diff --git a/docs/src/user_guide/misc.md b/docs/src/user_guide/misc.md index 8fced69c..ed47dd5c 100644 --- a/docs/src/user_guide/misc.md +++ b/docs/src/user_guide/misc.md @@ -9,8 +9,7 @@ To account for the fact that treatment variables are categorical variables we pr Such transformer can be created with: -```@example -using TMLE # hide +```julia TreatmentTransformer(;encoder=encoder()) ``` diff --git a/docs/src/user_guide/scm.md b/docs/src/user_guide/scm.md index 53e56efa..59a15354 100644 --- a/docs/src/user_guide/scm.md +++ b/docs/src/user_guide/scm.md @@ -10,20 +10,21 @@ In TMLE.jl, everything starts from the definition of a Structural Causal Model ( All models are wrong? Well maybe not the following: -```@example scm-incremental +```@example scm using TMLE # hide scm = SCM() ``` This model does not say anything about the random variables and is thus not really useful. Let's assume that we are interested in an outcome ``Y`` and that this outcome is determined by 8 other random variables. We can add this assumption to the model -```@example scm-incremental +```@example scm push!(scm, SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C])) ``` At this point, we haven't made any assumption regarding the functional form of the relationship between ``Y`` and its parents. We can add a further assumption by setting a statistical model for ``Y``, suppose we know it is generated from a logistic model, we can make that explicit: -```@example scm-incremental +```@example scm +using MLJLinearModels setmodel!(scm.Y, LogisticClassifier()) ``` @@ -32,14 +33,14 @@ setmodel!(scm.Y, LogisticClassifier()) - TMLE.jl is based on the main machine-learning framework in Julia: [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/). As such, any model respecting the MLJ interface is a valid model in TMLE.jl. - In real world scenarios, we usually don't know what is the true statistical model for each variable and want to keep it as large as possible. For this reason it is recommended to use Super-Learning which is implemented in MLJ by the [Stack](https://alan-turing-institute.github.io/MLJ.jl/dev/model_stacking/#Model-Stacking) and comes with theoretical properties. -- In the dataset, treatment variables are represented with [categorical data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/). This means the models that depend on such variables will need to properly deal with them. For this purpose we provide a [TreatmentTransformer](@ref) which can easily be combined with any `model` in a [Pipelining](https://alan-turing-institute.github.io/MLJ.jl/dev/linear_pipelines/) flavour with `with_encoder(model)`. +- In the dataset, treatment variables are represented with [categorical data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/). This means the models that depend on such variables will need to properly deal with them. For this purpose we provide a `TreatmentTransformer` which can easily be combined with any `model` in a [Pipelining](https://alan-turing-institute.github.io/MLJ.jl/dev/linear_pipelines/) flavour with `with_encoder(model)`. - The `SCM` has no knowledge of the data and thus cannot verify that the assumed statistical model is compatible with the data. This is done at a later stage. --- Let's now assume that we have a more complete knowledge of the problem and we also know how `T₁` and `T₂` depend on the rest of the variables in the system. -```@example scm-incremental +```@example scm push!(scm, SE(:T₁, [:W₁₁, :W₁₂, :W], model=LogisticClassifier())) push!(scm, SE(:T₂, [:W₂₁, :W₂₂, :W])) ``` @@ -48,8 +49,7 @@ push!(scm, SE(:T₂, [:W₂₁, :W₂₂, :W])) Instead of constructing the `SCM` incrementally, one can provide all the specified equations at once: -```@example scm-one-step -using TMLE # hide +```@example scm scm = SCM( SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C], with_encoder(LinearRegressor())), SE(:T₁, [:W₁₁, :W₁₂, :W], model=LogisticClassifier()), @@ -63,8 +63,7 @@ Noting that we have used the `with_encoder` function to reflect the fact that we There are many cases where we are interested in estimating the causal effect of a single treatment variable on a single outcome. Because it is typically only necessary to adjust for backdoor variables in order to identify this causal effect, we provide the `StaticConfoundedModel` interface to build such `SCM`: -```@example static-scm-1 -using TMLE # hide +```@example scm scm = StaticConfoundedModel( :Y, :T, [:W₁, :W₂]; covariates=[:C], @@ -77,8 +76,7 @@ The optional `covariates` are variables that influence the outcome but are not c This model can be extended to a plate-model with multiple treatments and multiple outcomes. In this case the set of confounders is assumed to confound all treatments which are in turn assumed to impact all outcomes. This can be defined as: -```@example static-scm-2 -using TMLE # hide +```@example scm scm = StaticConfoundedModel( [:Y₁, :Y₂], [:T₁, :T₂], [:W₁, :W₂]; covariates=[:C], diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index a7219788..e0179176 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -23,6 +23,7 @@ using StableRNGs using CategoricalArrays using TMLE using LogExpFunctions +using MLJLinearModels function make_dataset(;n=1000) rng = StableRNG(123) @@ -90,7 +91,7 @@ AVAILABLE_ESTIMANDS At the moment there are 3 main estimand types we can estimate in TMLE.jl, we provide below a few examples. -- The Interventional Conditional Mean (see: TODO): +- The Interventional Conditional Mean: ```@example walk-through cm = CM( @@ -100,10 +101,10 @@ cm = CM( ) ``` -- The Average Treatment Effect (see: TODO): +- The Average Treatment Effect: ```@example walk-through -ate = ATE( +total_ate = ATE( scm, outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)) @@ -115,7 +116,7 @@ marginal_ate_t1 = ATE( ) ``` -- The Interaction Average Treatment Effect (see: TODO): +- The Interaction Average Treatment Effect: ```@example walk-through iate = IATE( @@ -127,7 +128,7 @@ iate = IATE( ## Targeted Estimation -Then each parameter can be estimated by calling the `tmle` function. For example: +Then each parameter can be estimated by calling the `tmle!` function. For example: ```@example walk-through result, _ = tmle!(cm, dataset) @@ -136,10 +137,24 @@ result The `result` contains 3 main elements: -- The `TMLEEstimate` than can be accessed via: `tmle(result)`. -- The `OSEstimate` than can be accessed via: `ose(result)`. +- The `TMLEEstimate` than can be accessed via: + +```@example walk-through +tmle(result) +``` + +- The `OSEstimate` than can be accessed via: + +```@example walk-through +ose(result) +``` + - The naive initial estimate. +```@example walk-through +naive(result) +``` + The adjustment set is determined by the provided `adjustment_method` keyword. At the moment, only `BackdoorAdjustment` is available. However one can specify that extra covariates could be used to fit the outcome model. ```@example walk-through diff --git a/src/TMLE.jl b/src/TMLE.jl index e9cf2fbe..dd566ffa 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -31,7 +31,7 @@ export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS export fit!, optimize_ordering, optimize_ordering! -export tmle!, tmle, ose +export tmle!, tmle, ose, naive export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer, with_encoder diff --git a/src/estimate.jl b/src/estimate.jl index ee818fc7..da7f767c 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -23,7 +23,13 @@ struct TMLEResult{P <: Estimand, T<:AbstractFloat} initial::T end -function Base.show(io::IO, r::TMLEResult) +function Base.show(io::IO, ::MIME"text/plain", est::AsymptoticallyLinearEstimate) + testresult = OneSampleTTest(est) + data = [estimate(est) confint(testresult) pvalue(testresult);] + pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"]) +end + +function Base.show(io::IO, ::MIME"text/plain", r::TMLEResult) tmletest = OneSampleTTest(r.tmle) onesteptest = OneSampleTTest(r.onestep) data = [ @@ -36,6 +42,7 @@ end tmle(result::TMLEResult) = result.tmle ose(result::TMLEResult) = result.onestep +initial(result::TMLEResult) = result.initial """ Distributions.estimate(r::AsymptoticallyLinearEstimate) From 1ae36d4ec1e4d11e653752843bd159efd4ed9453 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 11 Aug 2023 00:31:00 +0100 Subject: [PATCH 036/151] update examples --- examples/double_robustness.jl | 6 +++--- examples/super_learning.jl | 27 +++++++++++---------------- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/examples/double_robustness.jl b/examples/double_robustness.jl index bda768cc..d2a44323 100644 --- a/examples/double_robustness.jl +++ b/examples/double_robustness.jl @@ -154,9 +154,9 @@ that we now have full coverage of the ground truth. function tmle_inference(data) Ψ = ATE( - :Y, - (Tcat=(case=1.0, control=0.0),), - [:W] + outcome=:Y, + treatment=(Tcat=(case=1.0, control=0.0),), + confounders=[:W] ) result, _ = tmle!(Ψ, data; verbosity=0) tmleresult = tmle(result) diff --git a/examples/super_learning.jl b/examples/super_learning.jl index 02b9ac71..31cb2618 100644 --- a/examples/super_learning.jl +++ b/examples/super_learning.jl @@ -179,37 +179,32 @@ nothing # hide Let us move to the targeted estimation step itself. We define the target estimand (the ATE) and the nuisance estimands specification: =# Ψ = ATE( - outcome = :Y, - treatment = (T=(case=true, control=false),), - confounders = [:W₁, :W₂] + outcome=:Y, + treatment=(T=(case=true, control=false),), + confounders=[:W₁, :W₂], + outcome_model=Q_super_learner, + treatment_model=G_super_learner ) -η_spec = NuisanceSpec( - Q_super_learner, - G_super_learner -) nothing # hide #= Finally run the TMLE procedure and check the result =# -tmle_result, cache = tmle!Ψ, η_spec, dataset) +tmle_result, _ = tmle!(Ψ, dataset) -test_result = OneSampleTTest(tmle_result.tmle, ψ₀) +test_result = OneSampleTTest(tmle(tmle_result), ψ₀) #= Now, what if we had used linear models only instead of the Super Learner? This is easy to check =# +setmodel!(Ψ.scm.Y, LinearRegressor()) +setmodel!(Ψ.scm.T, LogisticClassifier()) -η_spec_linear = NuisanceSpec( - LinearRegressor(), - LogisticClassifier(lambda=0) -) - -tmle_result_linear, cache = tmle!Ψ, η_spec_linear, dataset) +tmle_result_linear, cache = tmle!(Ψ, dataset) -test_result_linear = OneSampleTTest(tmle_result_linear.tmle, ψ₀) +test_result_linear = OneSampleTTest(tmle(tmle_result_linear), ψ₀) # using Test # hide From b956972bff3b0815b04d6b8dc83ccb645b1efc19 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 11 Aug 2023 00:37:59 +0100 Subject: [PATCH 037/151] add extension --- Project.toml | 2 +- ... => introduction_to_targeted_learning.jl_} | 0 ext/GraphsTMLEExt.jl | 48 +++++++++++++++++++ test/Project.toml | 5 +- test/graphs_extension.jl | 20 ++++++++ 5 files changed, 73 insertions(+), 2 deletions(-) rename examples/{introduction_to_targeted_learning.jl => introduction_to_targeted_learning.jl_} (100%) create mode 100644 ext/GraphsTMLEExt.jl create mode 100644 test/graphs_extension.jl diff --git a/Project.toml b/Project.toml index 41e5c4d8..25b1fa86 100644 --- a/Project.toml +++ b/Project.toml @@ -50,5 +50,5 @@ YAML = "0.4" Zygote = "0.6" Graphs = "1.8.0" GraphMakie = "0.5.5" -CairoMakie = "0.10.7" +CairoMakie = "0.10" julia = "1.6, 1.7, 1" diff --git a/examples/introduction_to_targeted_learning.jl b/examples/introduction_to_targeted_learning.jl_ similarity index 100% rename from examples/introduction_to_targeted_learning.jl rename to examples/introduction_to_targeted_learning.jl_ diff --git a/ext/GraphsTMLEExt.jl b/ext/GraphsTMLEExt.jl new file mode 100644 index 00000000..c3e68b73 --- /dev/null +++ b/ext/GraphsTMLEExt.jl @@ -0,0 +1,48 @@ +module GraphsTMLEExt + +using TMLE +using Graphs +using CairoMakie +using GraphMakie + +function edges_list_and_nodes_mapping(scm::SCM) + edge_list = [] + node_ids = Dict() + node_id = 1 + for (outcome, eq) in equations(scm) + if outcome ∉ keys(node_ids) + push!(node_ids, outcome => node_id) + node_id += 1 + end + outcome_id = node_ids[outcome] + for parent in TMLE.parents(eq) + if parent ∉ keys(node_ids) + push!(node_ids, parent => node_id) + node_id += 1 + end + parent_id = node_ids[parent] + push!(edge_list, Edge(parent_id, outcome_id)) + end + end + return edge_list, node_ids +end + +function DAG(scm::SCM) + edges_list, nodes_mapping = edges_list_and_nodes_mapping(scm) + return SimpleDiGraphFromIterator(edges_list), nodes_mapping +end + +function graphplot(scm::SCM; kwargs...) + graph, nodes_mapping = DAG(scm) + f, ax, p = graphplot(graph; + ilabels = [x[1] for x in sort(collect(nodes_mapping), by=x -> x[2])], + kwargs...) + Legend( + f[1, 2], + [MarkerElement(color=:black, marker='⋆', markersize = 10) for eq in equations(scm)], + [TMLE.string_repr(eq) for (_, eq) in equations(scm)]) + hidedecorations!(ax); + hidespines!(ax); + ax.aspect = DataAspect() + return fig, ax, p +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 7d219f1c..5633e74d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,13 @@ [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" +GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" @@ -22,7 +25,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] CSV = "0.10" DataFrames = "1.5" -EvoTrees = "0.14" +EvoTrees = "0.15" MLJLinearModels = "0.9" StableRNGs = "1.0" StatsBase = "0.33" diff --git a/test/graphs_extension.jl b/test/graphs_extension.jl new file mode 100644 index 00000000..1324d15a --- /dev/null +++ b/test/graphs_extension.jl @@ -0,0 +1,20 @@ +module TestGraphsTMLEExt + +using Test +using TMLE +using Graphs +using CairoMakie +using GraphMakie + +@testset "Test DAG and graphplot" begin + scm = SCM( + SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C]), + SE(:T₁, [:W₁₁, :W₁₂]), + SE(:T₂, [:W₂₁, :W₂₂]), + ) + dag, nodes_mapping = TMLE.DAG(scm) +end + +end + +true \ No newline at end of file From 8a315e5f0b505df600b65f151f98536b11f702cd Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 11 Aug 2023 13:20:01 +0100 Subject: [PATCH 038/151] more updates --- Project.toml | 15 +- README.md | 12 +- docs/Manifest.toml | 911 ++++++++++++++++------------------ docs/Project.toml | 5 +- examples/double_robustness.jl | 1 - examples/super_learning.jl | 142 ++---- ext/GraphsTMLEExt.jl | 48 -- src/TMLE.jl | 99 ++-- src/estimate.jl | 12 +- test/graphs_extension.jl | 20 - test/non_regression_test.jl | 2 +- 11 files changed, 533 insertions(+), 734 deletions(-) delete mode 100644 ext/GraphsTMLEExt.jl delete mode 100644 test/graphs_extension.jl diff --git a/Project.toml b/Project.toml index 25b1fa86..fe015a7b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,7 @@ name = "TMLE" uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" authors = ["Olivier Labayle"] -version = "0.11.4" - -[weakdeps] -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" - -[extensions] -GraphsTMLEExt = ["Graphs", "CairoMakie", "GraphMakie"] +version = "0.12.0" [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" @@ -36,7 +28,7 @@ AbstractDifferentiation = "0.4, 0.5" CategoricalArrays = "0.10" Distributions = "0.25" GLM = "1.8.2" -HypothesisTests = "0.10" +HypothesisTests = "0.10, 0.11" LogExpFunctions = "0.3" MLJBase = "0.19, 0.20, 0.21" MLJGLMInterface = "0.3.4" @@ -48,7 +40,4 @@ TableOperations = "1.2" Tables = "1.6" YAML = "0.4" Zygote = "0.6" -Graphs = "1.8.0" -GraphMakie = "0.5.5" -CairoMakie = "0.10" julia = "1.6, 1.7, 1" diff --git a/README.md b/README.md index b6088e0c..1740c70a 100644 --- a/README.md +++ b/README.md @@ -5,14 +5,6 @@ ![Codecov](https://img.shields.io/codecov/c/github/TARGENE/TMLE.jl/main) ![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/TARGENE/TMLE.jl) -Targeted Minimum Loss-Based Estimation (TMLE) is a framework for efficient estimation of pathwise differentiable estimands. The goal of this package is to provide an implementation on top of the [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) framework. +This package enables the estimation of various Causal Inference related estimands using Targeted Minimum Loss-Based Estimation (TMLE). TMLE.jl is based on top of [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/), which means any MLJ compliant machine-learning model can be used here. -**New to TMLE.jl?** The API design is still work in progress but you can already [get started now](https://targene.github.io/TMLE.jl/stable/) - -## TODO - -- Check parameter definition -- Fix tests -- Simplify code -- More API -- Documentation update +**New to TMLE.jl?** [Get started now](https://targene.github.io/TMLE.jl/stable/) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index a6cc21a5..bdd9fff0 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.0" +julia_version = "1.9.1" manifest_format = "2.0" -project_hash = "fe6007b69052f33942368ff494d941f2273e4e34" +project_hash = "82d74b11546c28748c1342b9b9b370ce23d4dd37" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -15,38 +15,21 @@ git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" version = "1.4.1" -[[deps.AbstractDifferentiation]] -deps = ["ExprTools", "LinearAlgebra", "Requires"] -git-tree-sha1 = "f83fd553acff1c6a7f5c4e6f5f2b5941d533cdc9" -uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" -version = "0.5.2" - - [deps.AbstractDifferentiation.extensions] - AbstractDifferentiationChainRulesCoreExt = "ChainRulesCore" - AbstractDifferentiationFiniteDifferencesExt = "FiniteDifferences" - AbstractDifferentiationForwardDiffExt = ["DiffResults", "ForwardDiff"] - AbstractDifferentiationReverseDiffExt = ["DiffResults", "ReverseDiff"] - AbstractDifferentiationTrackerExt = "Tracker" - AbstractDifferentiationZygoteExt = "Zygote" - - [deps.AbstractDifferentiation.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" - FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" - [[deps.AbstractFFTs]] deps = ["LinearAlgebra"] -git-tree-sha1 = "16b6dbc4cf7caee4e1e75c49485ec67b667098a0" +git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.3.1" -weakdeps = ["ChainRulesCore"] +version = "1.5.0" +weakdeps = ["ChainRulesCore", "Test"] [deps.AbstractFFTs.extensions] AbstractFFTsChainRulesCoreExt = "ChainRulesCore" + AbstractFFTsTestExt = "Test" + +[[deps.AbstractLattices]] +git-tree-sha1 = "f35684b7349da49fcc8a9e520e30e45dbb077166" +uuid = "398f06c4-4d28-53ec-89ca-5b2656b7603d" +version = "0.2.1" [[deps.AbstractTrees]] git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" @@ -55,9 +38,9 @@ version = "0.4.4" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "cc37d689f599e8df4f464b2fa3870ff7db7492ef" +git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.1" +version = "3.6.2" weakdeps = ["StaticArrays"] [deps.Adapt.extensions] @@ -74,10 +57,10 @@ uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" version = "1.1.1" [[deps.ArrayInterface]] -deps = ["Adapt", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "38911c7737e123b28182d89027f4216cfc8a9da7" +deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.4.3" +version = "7.4.11" [deps.ArrayInterface.extensions] ArrayInterfaceBandedMatricesExt = "BandedMatrices" @@ -95,12 +78,6 @@ version = "7.4.3" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -[[deps.ArrayInterfaceCore]] -deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d" -uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.29" - [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -111,10 +88,10 @@ uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" version = "0.1.0" [[deps.Automa]] -deps = ["Printf", "ScanByte", "TranscodingStreams"] -git-tree-sha1 = "d50976f217489ce799e366d9561d56a98a30d7fe" +deps = ["TranscodingStreams"] +git-tree-sha1 = "ef9997b3d5547c48b41c7bd8899e812a917b409d" uuid = "67c07d97-cdcb-5c2c-af73-a7f9c32a568b" -version = "0.8.2" +version = "0.8.4" [[deps.AxisAlgorithms]] deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] @@ -124,9 +101,9 @@ version = "1.0.1" [[deps.AxisArrays]] deps = ["Dates", "IntervalSets", "IterTools", "RangeArrays"] -git-tree-sha1 = "1dd4d9f5beebac0c03446918741b1a03dc5e5788" +git-tree-sha1 = "16351be62963a67ac4083f748fdb3cca58bfd52f" uuid = "39de3d68-74b9-583c-8d2d-e117c070f3a9" -version = "0.4.6" +version = "0.4.7" [[deps.BFloat16s]] deps = ["LinearAlgebra", "Printf", "Random", "Test"] @@ -147,12 +124,6 @@ git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" version = "0.1.7" -[[deps.BitTwiddlingConvenienceFunctions]] -deps = ["Static"] -git-tree-sha1 = "0c5f81f47bbbcf4aea7b2959135713459170798b" -uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" -version = "0.1.5" - [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" @@ -164,20 +135,32 @@ git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.2" -[[deps.CPUSummary]] -deps = ["CpuId", "IfElse", "Static"] -git-tree-sha1 = "2c144ddb46b552f72d7eafe7cc2f50746e41ea21" -uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" -version = "0.2.2" - [[deps.CRC32c]] uuid = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" +[[deps.CRlibm]] +deps = ["CRlibm_jll"] +git-tree-sha1 = "32abd86e3c2025db5172aa182b982debed519834" +uuid = "96374032-68de-5a5b-8d9e-752f78720389" +version = "1.0.1" + +[[deps.CRlibm_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "e329286945d0cfc04456972ea732551869af1cfc" +uuid = "4e9b3aee-d8a1-5a3d-ad8b-7d824db253f0" +version = "1.0.1+0" + +[[deps.CSV]] +deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] +git-tree-sha1 = "44dbf560808d49041989b8a96cae4cffbeb7966a" +uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" +version = "0.10.11" + [[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "280893f920654ebfaaaa1999fbd975689051f890" +deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "35160ef0f03b14768abfd68b830f8e3940e8e0dc" uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "4.2.0" +version = "4.4.0" [[deps.CUDA_Driver_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] @@ -204,10 +187,10 @@ uuid = "159f3aea-2a34-519c-b102-8c37f9878175" version = "1.0.5" [[deps.CairoMakie]] -deps = ["Base64", "Cairo", "Colors", "FFTW", "FileIO", "FreeType", "GeometryBasics", "LinearAlgebra", "Makie", "SHA", "SnoopPrecompile"] -git-tree-sha1 = "2aba202861fd2b7603beb80496b6566491229855" +deps = ["Base64", "Cairo", "Colors", "FFTW", "FileIO", "FreeType", "GeometryBasics", "LinearAlgebra", "Makie", "PrecompileTools", "SHA"] +git-tree-sha1 = "e041782fed7614b1726fa250f2bf24fd5c789689" uuid = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -version = "0.10.4" +version = "0.10.7" [[deps.Cairo_jll]] deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] @@ -251,29 +234,17 @@ version = "0.1.10" [deps.CategoricalDistributions.weakdeps] UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" -[[deps.ChainRules]] -deps = ["Adapt", "ChainRulesCore", "Compat", "Distributed", "GPUArraysCore", "IrrationalConstants", "LinearAlgebra", "Random", "RealDot", "SparseArrays", "Statistics", "StructArrays"] -git-tree-sha1 = "8bae903893aeeb429cf732cf1888490b93ecf265" -uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.49.0" - [[deps.ChainRulesCore]] deps = ["Compat", "LinearAlgebra", "SparseArrays"] git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" version = "1.16.0" -[[deps.CloseOpenIntervals]] -deps = ["Static", "StaticArrayInterface"] -git-tree-sha1 = "70232f82ffaab9dc52585e0dd043b5e0c6b714f1" -uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" -version = "0.1.12" - [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "9c209fb7536406834aa938fb149964b985de6c83" +git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.1" +version = "0.7.2" [[deps.ColorBrewer]] deps = ["Colors", "JSON", "Test"] @@ -283,9 +254,9 @@ version = "0.4.0" [[deps.ColorSchemes]] deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "be6ab11021cd29f0344d5c4357b163af05a48cba" +git-tree-sha1 = "d9a8f86737b665e15a9641ecbac64deef9ce6724" uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.21.0" +version = "3.23.0" [[deps.ColorTypes]] deps = ["FixedPointNumbers", "Random"] @@ -310,11 +281,6 @@ git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" version = "1.0.2" -[[deps.CommonSolve]] -git-tree-sha1 = "9441451ee712d1aec22edad62db1a9af3dc8d852" -uuid = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2" -version = "0.2.3" - [[deps.CommonSubexpressions]] deps = ["MacroTools", "Test"] git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" @@ -323,9 +289,9 @@ version = "0.3.0" [[deps.Compat]] deps = ["UUIDs"] -git-tree-sha1 = "7a60c856b9fa189eb34f5f8a6f6b5529b7942957" +git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.6.1" +version = "4.9.0" weakdeps = ["Dates", "LinearAlgebra"] [deps.Compat.extensions] @@ -343,21 +309,15 @@ version = "0.3.2" [[deps.ConcurrentUtilities]] deps = ["Serialization", "Sockets"] -git-tree-sha1 = "96d823b94ba8d187a6d8f0826e731195a74b90e9" +git-tree-sha1 = "5372dbbf8f0bdb8c700db5367132925c0771ef7e" uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.2.0" - -[[deps.Configurations]] -deps = ["ExproniconLite", "OrderedCollections", "TOML"] -git-tree-sha1 = "62a7c76dbad02fdfdaa53608104edf760938c4ca" -uuid = "5218b696-f38b-4ac9-8b61-a12ec717816d" -version = "0.17.4" +version = "2.2.1" [[deps.ConstructionBase]] deps = ["LinearAlgebra"] -git-tree-sha1 = "738fec4d684a9a6ee9598a8bfee305b26831f28c" +git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.2" +version = "1.5.3" weakdeps = ["IntervalSets", "StaticArrays"] [deps.ConstructionBase.extensions] @@ -369,12 +329,6 @@ git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.2" -[[deps.CpuId]] -deps = ["Markdown"] -git-tree-sha1 = "fcbb72b032692610bfbdb15018ac16a36cf2e406" -uuid = "adafc99b-e345-5852-983c-f28acb93d879" -version = "0.3.1" - [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" @@ -386,21 +340,21 @@ uuid = "a10d1c49-ce27-4219-8d33-6db1a4562965" version = "2.5.0" [[deps.DataAPI]] -git-tree-sha1 = "e8119c1a33d267e16108be441a287a6981ba1630" +git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.14.0" +version = "1.15.0" [[deps.DataFrames]] -deps = ["Compat", "DataAPI", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SnoopPrecompile", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "aa51303df86f8626a962fccb878430cdb0a97eee" +deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] +git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.5.0" +version = "1.6.1" [[deps.DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" +git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.13" +version = "0.18.15" [[deps.DataValueInterfaces]] git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" @@ -411,6 +365,16 @@ version = "1.0.0" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" +[[deps.DelaunayTriangulation]] +deps = ["DataStructures", "EnumX", "ExactPredicates", "Random", "SimpleGraphs"] +git-tree-sha1 = "a1d8532de83f8ce964235eff1edeff9581144d02" +uuid = "927a84f5-c5f4-47a5-9785-b46e178433df" +version = "0.7.2" +weakdeps = ["MakieCore"] + + [deps.DelaunayTriangulation.extensions] + DelaunayTriangulationMakieCoreExt = "MakieCore" + [[deps.DelimitedFiles]] deps = ["Mmap"] git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" @@ -425,25 +389,29 @@ version = "1.1.0" [[deps.DiffRules]] deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "a4ad7ef19d2cdc2eff57abbbe68032b1cd0bd8f8" +git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.13.0" +version = "1.15.1" [[deps.Distances]] -deps = ["LinearAlgebra", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "49eba9ad9f7ead780bfb7ee319f962c811c6d3b2" +deps = ["LinearAlgebra", "Statistics", "StatsAPI"] +git-tree-sha1 = "b6def76ffad15143924a2199f72a5cd883a2e8a9" uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.8" +version = "0.10.9" +weakdeps = ["SparseArrays"] + + [deps.Distances.extensions] + DistancesSparseArraysExt = "SparseArrays" [[deps.Distributed]] deps = ["Random", "Serialization", "Sockets"] uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" [[deps.Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "eead66061583b6807652281c0fbf291d7a9dc497" +deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] +git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.90" +version = "0.25.100" [deps.Distributions.extensions] DistributionsChainRulesCoreExt = "ChainRulesCore" @@ -461,9 +429,9 @@ version = "0.9.3" [[deps.Documenter]] deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "58fea7c536acd71f3eef6be3b21c0df5f3df88fd" +git-tree-sha1 = "39fd748a73dce4c05a9655475e437170d8fb1b67" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.24" +version = "0.27.25" [[deps.Downloads]] deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] @@ -488,28 +456,44 @@ git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" version = "0.3.0" +[[deps.EnumX]] +git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" +uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +version = "1.0.4" + +[[deps.ErrorfreeArithmetic]] +git-tree-sha1 = "d6863c556f1142a061532e79f611aa46be201686" +uuid = "90fa49ef-747e-5e6f-a989-263ba693cf1a" +version = "0.5.2" + [[deps.EvoTrees]] -deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "LoopVectorization", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "ae9f9fcfcf6d37bf04a3dae5c518732e31169a4d" +deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "1b418518c0eb1fd1ef0a6d0bfc8051e6abb1232b" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.14.9" +version = "0.15.2" + +[[deps.ExactPredicates]] +deps = ["IntervalArithmetic", "Random", "StaticArraysCore", "Test"] +git-tree-sha1 = "276e83bc8b21589b79303b9985c321024ffdf59c" +uuid = "429591f6-91af-11e9-00e2-59fbe8cec110" +version = "2.2.5" + +[[deps.ExceptionUnwrapping]] +deps = ["Test"] +git-tree-sha1 = "e90caa41f5a86296e014e148ee061bd6c3edec96" +uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" +version = "0.1.9" [[deps.Expat_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "bad72f730e9e91c08d9427d5e8db95478a3c323d" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c" uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.4.8+0" +version = "2.5.0+0" [[deps.ExprTools]] -git-tree-sha1 = "c1d06d129da9f55715c6c212866f5b1bddc5fa00" +git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.9" - -[[deps.ExproniconLite]] -deps = ["Pkg", "TOML"] -git-tree-sha1 = "c2eb763acf6e13e75595e0737a07a0bec0ce2147" -uuid = "55351af7-c7e9-48d6-89ff-24e801d99491" -version = "0.7.11" +version = "0.1.10" [[deps.Extents]] git-tree-sha1 = "5e1e4c53fa39afe63a7d356e30452249365fba99" @@ -523,16 +507,16 @@ uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" version = "0.4.1" [[deps.FFMPEG_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Pkg", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] -git-tree-sha1 = "74faea50c1d007c85837327f6775bea60b5492dd" +deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] +git-tree-sha1 = "466d45dc38e15794ec7d5d63ec03d776a9aff36e" uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" -version = "4.4.2+2" +version = "4.4.4+1" [[deps.FFTW]] deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] -git-tree-sha1 = "f9818144ce7c8c41edf5c4c179c684d92aa4d9fe" +git-tree-sha1 = "b4fbdd20c889804969571cc589900803edda16b7" uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.6.0" +version = "1.7.1" [[deps.FFTW_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -540,26 +524,48 @@ git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea" uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" version = "3.3.10+0" +[[deps.FastRounding]] +deps = ["ErrorfreeArithmetic", "LinearAlgebra"] +git-tree-sha1 = "6344aa18f654196be82e62816935225b3b9abe44" +uuid = "fa42c844-2597-5d31-933b-ebd51ab2693f" +version = "0.3.1" + [[deps.FileIO]] deps = ["Pkg", "Requires", "UUIDs"] git-tree-sha1 = "299dc33549f68299137e51e6d49a13b5b1da9673" uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" version = "1.16.1" +[[deps.FilePathsBase]] +deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] +git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12" +uuid = "48062228-2e41-5def-b9a4-89aafe57970f" +version = "0.9.20" + [[deps.FileWatching]] uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" [[deps.FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "7072f1e3e5a8be51d525d64f63d3ec1287ff2790" +git-tree-sha1 = "f372472e8672b1d993e93dada09e23139b509f9e" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.13.11" +version = "1.5.0" [[deps.FiniteDiff]] -deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "6604e18a0220650dbbea7854938768f15955dd8e" +deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] +git-tree-sha1 = "c6e4a1fbe73b31a3dea94b1da449503b8830c306" uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.20.0" +version = "2.21.1" + + [deps.FiniteDiff.extensions] + FiniteDiffBandedMatricesExt = "BandedMatrices" + FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" + FiniteDiffStaticArraysExt = "StaticArrays" + + [deps.FiniteDiff.weakdeps] + BandedMatrices = "aae01518-5342-5314-be14-df237901396f" + BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" + StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [[deps.FixedPointNumbers]] deps = ["Statistics"] @@ -581,9 +587,9 @@ version = "0.4.2" [[deps.ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "00e252f4d706b3d55a8863432e742bf5717b498d" +git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.35" +version = "0.10.36" weakdeps = ["StaticArrays"] [deps.ForwardDiff.extensions] @@ -625,33 +631,33 @@ version = "1.8.3" [[deps.GPUArrays]] deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "9ade6983c3dbbd492cf5729f865fe030d1541463" +git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1" uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.6.6" +version = "8.8.1" [[deps.GPUArraysCore]] deps = ["Adapt"] -git-tree-sha1 = "1cd7f0af1aa58abc02ea1d872953a97359cb87fa" +git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.4" +version = "0.1.5" [[deps.GPUCompiler]] deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "e9a9173cd77e16509cdf9c1663fda19b22a518b7" +git-tree-sha1 = "72b2e3c2ba583d1a7aa35129e56cf92e07c083e3" uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.19.3" +version = "0.21.4" [[deps.GeoInterface]] deps = ["Extents"] -git-tree-sha1 = "0eb6de0b312688f852f347171aba888658e29f20" +git-tree-sha1 = "bb198ff907228523f3dee1070ceee63b9359b6ab" uuid = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" -version = "1.3.0" +version = "1.3.1" [[deps.GeometryBasics]] -deps = ["EarCut_jll", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] -git-tree-sha1 = "659140c9375afa2f685e37c1a0b9c9a60ef56b40" +deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] +git-tree-sha1 = "424a5a6ce7c5d97cca7bcc4eac551b97294c54af" uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" -version = "0.4.7" +version = "0.4.9" [[deps.Gettext_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] @@ -679,9 +685,9 @@ version = "1.3.14+0" [[deps.GridLayoutBase]] deps = ["GeometryBasics", "InteractiveUtils", "Observables"] -git-tree-sha1 = "678d136003ed5bceaab05cf64519e3f956ffa4ba" +git-tree-sha1 = "f57a64794b336d4990d90f80b147474b869b1bc4" uuid = "3955a311-db13-416c-9275-1d80ed98e5e9" -version = "0.9.1" +version = "0.9.2" [[deps.Grisu]] git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" @@ -689,10 +695,10 @@ uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" version = "1.0.2" [[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "42c6b37c6a2242cb646a1dd5518631db0be9b967" +deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] +git-tree-sha1 = "cb56ccdd481c0dd7f975ad2b3b62d9eda088f7e2" uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.8.1" +version = "1.9.14" [[deps.HarfBuzz_jll]] deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] @@ -700,46 +706,23 @@ git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" version = "2.8.1+1" -[[deps.HostCPUFeatures]] -deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"] -git-tree-sha1 = "734fd90dd2f920a2f1921d5388dcebe805b262dc" -uuid = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" -version = "0.1.14" - [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "84204eae2dd237500835990bcade263e27674a93" +git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.16" - -[[deps.HypothesisTests]] -deps = ["Combinatorics", "Distributions", "LinearAlgebra", "Random", "Rmath", "Roots", "Statistics", "StatsBase"] -git-tree-sha1 = "fee0691e3336a71503dada09ed61bed786b0f59f" -uuid = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" -version = "0.10.13" +version = "0.3.23" [[deps.IOCapture]] deps = ["Logging", "Random"] -git-tree-sha1 = "f7be53659ab06ddc986428d3a9dcc95f6fa6705a" +git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.2" - -[[deps.IRTools]] -deps = ["InteractiveUtils", "MacroTools", "Test"] -git-tree-sha1 = "eac00994ce3229a464c2847e956d77a2c64ad3a5" -uuid = "7869d1d1-7146-5819-86e3-90919afe41df" -version = "0.4.10" - -[[deps.IfElse]] -git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" -uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -version = "0.1.1" +version = "0.2.3" [[deps.ImageAxes]] deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"] -git-tree-sha1 = "c54b581a83008dc7f292e205f4c409ab5caa0f04" +git-tree-sha1 = "2e4520d67b0cef90865b3ef727594d2a58e0e1f8" uuid = "2803e5a7-5153-5ecf-9a86-9b4c37f5f5ac" -version = "0.6.10" +version = "0.6.11" [[deps.ImageBase]] deps = ["ImageCore", "Reexport"] @@ -755,15 +738,15 @@ version = "0.9.4" [[deps.ImageIO]] deps = ["FileIO", "IndirectArrays", "JpegTurbo", "LazyModules", "Netpbm", "OpenEXR", "PNGFiles", "QOI", "Sixel", "TiffImages", "UUIDs"] -git-tree-sha1 = "342f789fd041a55166764c351da1710db97ce0e0" +git-tree-sha1 = "bca20b2f5d00c4fbc192c3212da8fa79f4688009" uuid = "82e4d734-157c-48bb-816b-45c225c6df19" -version = "0.6.6" +version = "0.6.7" [[deps.ImageMetadata]] deps = ["AxisArrays", "ImageAxes", "ImageBase", "ImageCore"] -git-tree-sha1 = "36cbaebed194b292590cba2593da27b34763804a" +git-tree-sha1 = "355e2b974f2e3212a75dfb60519de21361ad3cb7" uuid = "bc367c6b-8a6b-528e-b4bd-a4b897500b49" -version = "0.9.8" +version = "0.9.9" [[deps.Imath_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -787,11 +770,16 @@ git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" version = "1.4.0" +[[deps.IntegerMathUtils]] +git-tree-sha1 = "b8ffb903da9f7b8cf695a8bead8e01814aa24b30" +uuid = "18e54dd8-cb9d-406c-a71d-865a43cbb235" +version = "0.1.2" + [[deps.IntelOpenMP_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "0cb9352ef2e01574eeebdb102948a58740dcaf83" +git-tree-sha1 = "ad37c091f7d7daf900963171600d7c1c5c3ede32" uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2023.1.0+0" +version = "2023.2.0+0" [[deps.InteractiveUtils]] deps = ["Markdown"] @@ -803,11 +791,21 @@ git-tree-sha1 = "721ec2cf720536ad005cb38f50dbba7b02419a15" uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" version = "0.14.7" +[[deps.IntervalArithmetic]] +deps = ["CRlibm", "FastRounding", "LinearAlgebra", "Markdown", "Random", "RecipesBase", "RoundingEmulator", "SetRounding", "StaticArrays"] +git-tree-sha1 = "5ab7744289be503d76a944784bac3f2df7b809af" +uuid = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" +version = "0.20.9" + [[deps.IntervalSets]] -deps = ["Dates", "Random", "Statistics"] -git-tree-sha1 = "16c0cc91853084cb5f58a78bd209513900206ce6" +deps = ["Dates", "Random"] +git-tree-sha1 = "8e59ea773deee525c99a8018409f64f19fb719e6" uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.7.4" +version = "0.7.7" +weakdeps = ["Statistics"] + + [deps.IntervalSets.extensions] + IntervalSetsStatisticsExt = "Statistics" [[deps.InvertedIndices]] git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" @@ -826,9 +824,9 @@ uuid = "f1662d9f-8043-43de-a69a-05efc1cc6ff4" version = "0.1.1" [[deps.IterTools]] -git-tree-sha1 = "fa6287a4469f5e048d763df38279ee729fbd44e5" +git-tree-sha1 = "4ced6667f9974fc5c5943fa5e2ef1ca43ea9e450" uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" -version = "1.4.0" +version = "1.8.0" [[deps.IterationControl]] deps = ["EarlyStopping", "InteractiveUtils"] @@ -861,9 +859,9 @@ version = "0.21.4" [[deps.JpegTurbo]] deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] -git-tree-sha1 = "106b6aa272f294ba47e96bd3acbabdc0407b5c60" +git-tree-sha1 = "327713faef2a3e5c80f96bf38d1fa26f7a6ae29e" uuid = "b835a17e-a41a-41e7-81f0-2f016b05efe0" -version = "0.1.2" +version = "0.1.3" [[deps.JpegTurbo_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl"] @@ -872,10 +870,16 @@ uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" version = "2.1.91+0" [[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "47be64f040a7ece575c2b5f53ca6da7b548d69f4" +deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] +git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.4" +version = "0.9.8" + + [deps.KernelAbstractions.extensions] + EnzymeExt = "EnzymeCore" + + [deps.KernelAbstractions.weakdeps] + EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" [[deps.KernelDensity]] deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] @@ -891,15 +895,21 @@ version = "3.100.1+0" [[deps.LLVM]] deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "a8960cae30b42b66dd41808beb76490519f6f9e2" +git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "5.0.0" +version = "6.1.0" [[deps.LLVMExtra_jll]] deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "09b7505cc0b1cee87e5d4a26eea61d2e1b0dcd35" +git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c" uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.21+0" +version = "0.0.23+0" + +[[deps.LLVMOpenMP_jll]] +deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] +git-tree-sha1 = "f689897ccbe049adb19a065c495e75f372ecd42b" +uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" +version = "15.0.4+0" [[deps.LZO_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -914,15 +924,9 @@ version = "1.3.0" [[deps.LatinHypercubeSampling]] deps = ["Random", "StableRNGs", "StatsBase", "Test"] -git-tree-sha1 = "42938ab65e9ed3c3029a8d2c58382ca75bdab243" +git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" -version = "1.8.0" - -[[deps.LayoutPointers]] -deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] -git-tree-sha1 = "88b8f66b604da079a627b6fb2860d3704a6729a1" -uuid = "10f19ff3-798f-405d-979b-55457f8fc047" -version = "0.1.14" +version = "1.9.0" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -991,6 +995,12 @@ git-tree-sha1 = "7f3efec06033682db852f8b3bc3c1d2b0a0ab066" uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" version = "2.36.0+0" +[[deps.LightXML]] +deps = ["Libdl", "XML2_jll"] +git-tree-sha1 = "e129d9391168c677cd4800f5c0abb1ed8cb3794f" +uuid = "9c8b4983-aa76-5018-a973-4c85ecc9e179" +version = "0.9.0" + [[deps.LineSearches]] deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] git-tree-sha1 = "7bbea35cec17305fc70a0e5b4641477dc0789d9d" @@ -1001,23 +1011,35 @@ version = "7.2.0" deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[[deps.LinearAlgebraX]] +deps = ["LinearAlgebra", "Mods", "Permutations", "Primes", "SimplePolynomials"] +git-tree-sha1 = "558a338f1eeabe933f9c2d4052aa7c2c707c3d52" +uuid = "9b3f67b0-2d00-526e-9884-9e4938f8fb88" +version = "0.1.12" + [[deps.LinearMaps]] -deps = ["ChainRulesCore", "LinearAlgebra", "SparseArrays", "Statistics"] -git-tree-sha1 = "4af48c3585177561e9f0d24eb9619ad3abf77cc7" +deps = ["LinearAlgebra"] +git-tree-sha1 = "6698ab5e662b47ffc63a82b2f43c1cee015cf80d" uuid = "7a12625a-238d-50fd-b39a-03d52299707e" -version = "3.10.0" +version = "3.11.0" +weakdeps = ["ChainRulesCore", "SparseArrays", "Statistics"] + + [deps.LinearMaps.extensions] + LinearMapsChainRulesCoreExt = "ChainRulesCore" + LinearMapsSparseArraysExt = "SparseArrays" + LinearMapsStatisticsExt = "Statistics" [[deps.Literate]] deps = ["Base64", "IOCapture", "JSON", "REPL"] -git-tree-sha1 = "1c4418beaa6664041e0f9b48f0710f57bff2fcbe" +git-tree-sha1 = "a1a0d4ff9f785a2184baca7a5c89e664f144143d" uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306" -version = "2.14.0" +version = "2.14.1" [[deps.LogExpFunctions]] deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "0a1b7c2863e44523180fdb3146534e265a91870b" +git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f" uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.23" +version = "0.3.24" [deps.LogExpFunctions.extensions] LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" @@ -1038,40 +1060,33 @@ git-tree-sha1 = "cedb76b37bc5a6c702ade66be44f831fa23c681e" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.0.0" -[[deps.LoopVectorization]] -deps = ["ArrayInterface", "ArrayInterfaceCore", "CPUSummary", "CloseOpenIntervals", "DocStringExtensions", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "PrecompileTools", "SIMDTypes", "SLEEFPirates", "Static", "StaticArrayInterface", "ThreadingUtilities", "UnPack", "VectorizationBase"] -git-tree-sha1 = "3bb62b5003bc7d2d49f26663484267dc49fa1bf5" -uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.12.159" -weakdeps = ["ChainRulesCore", "ForwardDiff", "SpecialFunctions"] - - [deps.LoopVectorization.extensions] - ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"] - SpecialFunctionsExt = "SpecialFunctions" - [[deps.LossFunctions]] -deps = ["CategoricalArrays", "Markdown", "Statistics"] -git-tree-sha1 = "44a7bfeb7b5eb9386a62b9cccc6e21f406c15bea" +deps = ["Markdown", "Requires", "Statistics"] +git-tree-sha1 = "c2b72b61d2e3489b1f9cae3403226b21ec90c943" uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" -version = "0.10.0" +version = "0.11.0" +weakdeps = ["CategoricalArrays"] + + [deps.LossFunctions.extensions] + LossFunctionsCategoricalArraysExt = "CategoricalArrays" [[deps.MKL_jll]] deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "2ce8695e1e699b68702c03402672a69f54b8aca9" +git-tree-sha1 = "eb006abbd7041c28e0d16260e50a24f8f9104913" uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -version = "2022.2.0+0" +version = "2023.2.0+0" [[deps.MLJ]] deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "ScientificTypes", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "80149328ca780b522b5a95e402450d10df7904f2" +git-tree-sha1 = "d26cd777c711c332019b39445823cbb1f6cdb7e5" uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" -version = "0.19.1" +version = "0.19.2" [[deps.MLJBase]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "4cc167b6c0a3ab25d7050e4ac38fe119e97cd1ab" +git-tree-sha1 = "2c9d6b9c627a80f6e6acbc6193026f455581fd04" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "0.21.11" +version = "0.21.13" [[deps.MLJEnsembles]] deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] @@ -1093,9 +1108,9 @@ version = "0.5.1" [[deps.MLJLinearModels]] deps = ["DocStringExtensions", "IterativeSolvers", "LinearAlgebra", "LinearMaps", "MLJModelInterface", "Optim", "Parameters"] -git-tree-sha1 = "c811b3877f1328179cef6662388d200c78b95c09" +git-tree-sha1 = "c92bf0ea37bf51e1ef0160069c572825819748b8" uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692" -version = "0.9.1" +version = "0.9.2" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] @@ -1105,9 +1120,9 @@ version = "1.8.0" [[deps.MLJModels]] deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "21acf47dc53ccc3d68e38ac7629756cd09b599f5" +git-tree-sha1 = "2b49f04f70266a2b040eb46ece157c4f5c1b0c13" uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.6" +version = "0.16.10" [[deps.MLJTuning]] deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase"] @@ -1122,26 +1137,21 @@ uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" version = "0.5.10" [[deps.Makie]] -deps = ["Animations", "Base64", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG", "FileIO", "FixedPointNumbers", "Formatting", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MakieCore", "Markdown", "Match", "MathTeXEngine", "MiniQhull", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "Printf", "Random", "RelocatableFolders", "Setfield", "Showoff", "SignedDistanceFields", "SnoopPrecompile", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"] -git-tree-sha1 = "74657542dc85c3b72b8a5a9392d57713d8b7a999" +deps = ["Animations", "Base64", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG", "FileIO", "FixedPointNumbers", "Formatting", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "Match", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Setfield", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"] +git-tree-sha1 = "729640354756782c89adba8857085a69e19be7ab" uuid = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" -version = "0.19.4" +version = "0.19.7" [[deps.MakieCore]] deps = ["Observables"] -git-tree-sha1 = "9926529455a331ed73c19ff06d16906737a876ed" +git-tree-sha1 = "87a85ff81583bd392642869557cb633532989517" uuid = "20f20a25-4f0e-4fdf-b5d1-57303727442b" -version = "0.6.3" - -[[deps.ManualMemory]] -git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" -uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" -version = "0.1.8" +version = "0.6.4" [[deps.MappedArrays]] -git-tree-sha1 = "e8b359ef06ec72e8c030463fe02efe5527ee5142" +git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.1" +version = "0.4.2" [[deps.Markdown]] deps = ["Base64"] @@ -1169,12 +1179,6 @@ deps = ["Artifacts", "Libdl"] uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" version = "2.28.2+0" -[[deps.MiniQhull]] -deps = ["QhullMiniWrapper_jll"] -git-tree-sha1 = "9dc837d180ee49eeb7c8b77bb1c860452634b0d1" -uuid = "978d7f02-9e05-4691-894f-ae31a51d76ca" -version = "0.4.0" - [[deps.Missings]] deps = ["DataAPI"] git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" @@ -1184,6 +1188,11 @@ version = "1.1.0" [[deps.Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[deps.Mods]] +git-tree-sha1 = "61be59e4daffff43a8cec04b5e0dc773cbb5db3a" +uuid = "7475f97c-0381-53b1-977b-4c60186c8d62" +version = "1.3.3" + [[deps.MosaicViews]] deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" @@ -1194,6 +1203,11 @@ version = "0.3.4" uuid = "14a3606d-f60d-562e-9121-12d972cd8159" version = "2022.10.11" +[[deps.Multisets]] +git-tree-sha1 = "8d852646862c96e226367ad10c8af56099b4047e" +uuid = "3b2b4ff1-bcff-5658-a3ee-dbcf1ce5ac09" +version = "0.4.4" + [[deps.NLSolversBase]] deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" @@ -1208,9 +1222,9 @@ version = "1.0.2" [[deps.NearestNeighborModels]] deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "c2179f9d8de066c481b889a1426068c5831bb10b" +git-tree-sha1 = "e411143a8362926e4284a54e745972e939fbab78" uuid = "636a865e-7cf4-491e-846c-de09b730eb36" -version = "0.2.2" +version = "0.2.3" [[deps.NearestNeighbors]] deps = ["Distances", "StaticArrays"] @@ -1220,9 +1234,9 @@ version = "0.4.13" [[deps.Netpbm]] deps = ["FileIO", "ImageCore", "ImageMetadata"] -git-tree-sha1 = "5ae7ca23e13855b3aba94550f26146c01d259267" +git-tree-sha1 = "d92b107dbb887293622df7697a2223f9f8176fcd" uuid = "f09324ee-3d7c-5217-9330-fc30815ba969" -version = "1.1.0" +version = "1.1.1" [[deps.NetworkLayout]] deps = ["GeometryBasics", "LinearAlgebra", "Random", "Requires", "SparseArrays", "StaticArrays"] @@ -1241,9 +1255,9 @@ version = "0.5.4" [[deps.OffsetArrays]] deps = ["Adapt"] -git-tree-sha1 = "82d7c9e310fe55aa54996e6f7f94674e2a38fcb4" +git-tree-sha1 = "2ac17d29c523ce1cd38e27785a7d23024853a4bb" uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.12.9" +version = "1.12.10" [[deps.Ogg_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1286,10 +1300,10 @@ uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" version = "1.4.1" [[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9ff31d101d987eb9d66bd8b176ac7c277beccd09" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e78db7bd5c26fc5a6911b50a47ee302219157ea8" uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "1.1.20+0" +version = "3.0.10+0" [[deps.OpenSpecFun_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] @@ -1299,9 +1313,9 @@ version = "0.5.5+0" [[deps.Optim]] deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] -git-tree-sha1 = "a89b11f0f354f06099e4001c151dffad7ebab015" +git-tree-sha1 = "e3a6546c1577bfd701771b477b794a52949e7594" uuid = "429524aa-4258-5aef-a3af-852621145aeb" -version = "1.7.5" +version = "1.7.6" [[deps.Opus_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] @@ -1310,9 +1324,9 @@ uuid = "91d4177d-7536-5919-b921-800302f37372" version = "1.3.2+0" [[deps.OrderedCollections]] -git-tree-sha1 = "d321bf2de576bf25ec4d3e4360faca399afca282" +git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.0" +version = "1.6.2" [[deps.PCRE2_jll]] deps = ["Artifacts", "Libdl"] @@ -1327,9 +1341,9 @@ version = "0.11.17" [[deps.PNGFiles]] deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] -git-tree-sha1 = "f809158b27eba0c18c269cf2a2be6ed751d3e81d" +git-tree-sha1 = "9b02b27ac477cad98114584ff964e3052f656a0f" uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" -version = "0.3.17" +version = "0.4.0" [[deps.Packing]] deps = ["GeometryBasics"] @@ -1356,16 +1370,22 @@ uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" version = "0.12.3" [[deps.Parsers]] -deps = ["Dates", "SnoopPrecompile"] -git-tree-sha1 = "478ac6c952fddd4399e71d4779797c538d0ff2bf" +deps = ["Dates", "PrecompileTools", "UUIDs"] +git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.5.8" +version = "2.7.2" + +[[deps.Permutations]] +deps = ["Combinatorics", "LinearAlgebra", "Random"] +git-tree-sha1 = "6e6cab1c54ae2382bcc48866b91cf949cea703a1" +uuid = "2ae35dd2-176d-5d53-8349-f30d82d94d4f" +version = "0.4.16" [[deps.Pixman_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "b4f5d02549a10e20780a24fce72bea96b6329e29" +deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] +git-tree-sha1 = "64779bc4c9784fee475689a1752ef4d5747c5e87" uuid = "30392449-352a-5448-841d-b1acce4e97dc" -version = "0.40.1+0" +version = "0.42.2+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] @@ -1384,17 +1404,27 @@ git-tree-sha1 = "f92e1315dadf8c46561fb9396e525f7200cdc227" uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" version = "1.3.5" -[[deps.PolyesterWeave]] -deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] -git-tree-sha1 = "240d7170f5ffdb285f9427b92333c3463bf65bf6" -uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" -version = "0.2.1" - [[deps.PolygonOps]] git-tree-sha1 = "77b3d3605fc1cd0b42d95eba87dfcd2bf67d5ff6" uuid = "647866c9-e3ac-4575-94e7-e3d426903924" version = "0.1.2" +[[deps.Polynomials]] +deps = ["LinearAlgebra", "RecipesBase"] +git-tree-sha1 = "3aa2bb4982e575acd7583f01531f241af077b163" +uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" +version = "3.2.13" + + [deps.Polynomials.extensions] + PolynomialsChainRulesCoreExt = "ChainRulesCore" + PolynomialsMakieCoreExt = "MakieCore" + PolynomialsMutableArithmeticsExt = "MutableArithmetics" + + [deps.Polynomials.weakdeps] + ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" + MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" + MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" + [[deps.PooledArrays]] deps = ["DataAPI", "Future"] git-tree-sha1 = "a6062fe4063cdafe78f4a0a81cfffb89721b30e7" @@ -1409,9 +1439,9 @@ version = "0.2.4" [[deps.PrecompileTools]] deps = ["Preferences"] -git-tree-sha1 = "259e206946c293698122f63e2b513a7c99a244e8" +git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.1.1" +version = "1.1.2" [[deps.Preferences]] deps = ["TOML"] @@ -1420,15 +1450,21 @@ uuid = "21216c6a-2e73-6563-6e65-726566657250" version = "1.4.0" [[deps.PrettyPrinting]] -git-tree-sha1 = "4be53d093e9e37772cc89e1009e8f6ad10c4681b" +git-tree-sha1 = "22a601b04a154ca38867b991d5017469dc75f2db" uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" -version = "0.4.0" +version = "0.4.1" [[deps.PrettyTables]] -deps = ["Crayons", "Formatting", "LaTeXStrings", "Markdown", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "213579618ec1f42dea7dd637a42785a608b1ea9c" +deps = ["Crayons", "LaTeXStrings", "Markdown", "Printf", "Reexport", "StringManipulation", "Tables"] +git-tree-sha1 = "ee094908d720185ddbdc58dbe0c1cbe35453ec7a" uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.2.4" +version = "2.2.7" + +[[deps.Primes]] +deps = ["IntegerMathUtils"] +git-tree-sha1 = "4c9f306e5d6603ae203c2000dd460d81a5251489" +uuid = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" +version = "0.5.4" [[deps.Printf]] deps = ["Unicode"] @@ -1446,18 +1482,6 @@ git-tree-sha1 = "18e8f4d1426e965c7b532ddd260599e1510d26ce" uuid = "4b34888f-f399-49d4-9bb3-47ed5cae4e65" version = "1.0.0" -[[deps.QhullMiniWrapper_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Qhull_jll"] -git-tree-sha1 = "607cf73c03f8a9f83b36db0b86a3a9c14179621f" -uuid = "460c41e3-6112-5d7f-b78c-b6823adb3f2d" -version = "1.0.0+1" - -[[deps.Qhull_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "238dd7e2cc577281976b9681702174850f8d4cbc" -uuid = "784f63db-0788-585a-bace-daefebcd302b" -version = "8.0.1001+0" - [[deps.QuadGK]] deps = ["DataStructures", "LinearAlgebra"] git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee" @@ -1491,20 +1515,14 @@ version = "0.3.2" [[deps.Ratios]] deps = ["Requires"] -git-tree-sha1 = "6d7bb727e76147ba18eed998700998e17b8e4911" +git-tree-sha1 = "1342a47bf3260ee108163042310d26f2be5ec90b" uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" -version = "0.4.4" +version = "0.4.5" weakdeps = ["FixedPointNumbers"] [deps.Ratios.extensions] RatiosFixedPointNumbersExt = "FixedPointNumbers" -[[deps.RealDot]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "9f0a1b71baaf7650f4fa8a1d168c7fb6ee41f0c9" -uuid = "c1ae055f-0cd5-4b69-90a6-9a35b1a98df9" -version = "0.1.0" - [[deps.RecipesBase]] deps = ["PrecompileTools"] git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" @@ -1528,6 +1546,12 @@ git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" uuid = "ae029012-a4dd-5104-9daa-d747884805df" version = "1.3.0" +[[deps.RingLists]] +deps = ["Random"] +git-tree-sha1 = "9712ebc42e91850f35272b48eb840e60c0270ec0" +uuid = "286e9d63-9694-5540-9e3c-4e6708fa07b2" +version = "0.2.7" + [[deps.Rmath]] deps = ["Random", "Rmath_jll"] git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" @@ -1540,41 +1564,15 @@ git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" version = "0.4.0+0" -[[deps.Roots]] -deps = ["ChainRulesCore", "CommonSolve", "Printf", "Setfield"] -git-tree-sha1 = "e19c09f5cc868785766f86435ba40576cf751257" -uuid = "f2b01f46-fcfa-551c-844a-d8ac1e96c665" -version = "2.0.14" - - [deps.Roots.extensions] - RootsForwardDiffExt = "ForwardDiff" - RootsIntervalRootFindingExt = "IntervalRootFinding" - - [deps.Roots.weakdeps] - ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" - IntervalRootFinding = "d2bf35a9-74e0-55ec-b149-d360ff49b807" +[[deps.RoundingEmulator]] +git-tree-sha1 = "40b9edad2e5287e05bd413a38f61a8ff55b9557b" +uuid = "5eaf0fd0-dfba-4ccb-bf02-d820a40db705" +version = "0.2.1" [[deps.SHA]] uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" -[[deps.SIMD]] -deps = ["SnoopPrecompile"] -git-tree-sha1 = "8b20084a97b004588125caebf418d8cab9e393d1" -uuid = "fdea26ae-647d-5447-a871-4b548cad5224" -version = "3.4.4" - -[[deps.SIMDTypes]] -git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" -uuid = "94e857df-77ce-4151-89e5-788b33177be4" -version = "0.1.0" - -[[deps.SLEEFPirates]] -deps = ["IfElse", "Static", "VectorizationBase"] -git-tree-sha1 = "cda0aece8080e992f6370491b08ef3909d1c04e7" -uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" -version = "0.6.38" - [[deps.SQLite]] deps = ["DBInterface", "Random", "SQLite_jll", "Serialization", "Tables", "WeakRefStrings"] git-tree-sha1 = "eb9a473c9b191ced349d04efa612ec9f39c087ea" @@ -1583,15 +1581,9 @@ version = "1.6.0" [[deps.SQLite_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "54d66b0f69f4578f4988fc08d579783fcdcd764f" +git-tree-sha1 = "4619dd3363610d94fb42a95a6dc35b526a26d0ef" uuid = "76ed43ae-9a5d-5a62-8c75-30186b810ce8" -version = "3.41.0+0" - -[[deps.ScanByte]] -deps = ["Libdl", "SIMD"] -git-tree-sha1 = "2436b15f376005e8790e318329560dcc67188e84" -uuid = "7b38b023-a4d7-4c5e-8d43-3f3097f304eb" -version = "0.3.3" +version = "3.42.0+0" [[deps.ScientificTypes]] deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"] @@ -1612,19 +1604,30 @@ version = "1.2.0" [[deps.SentinelArrays]] deps = ["Dates", "Random"] -git-tree-sha1 = "77d3c4726515dca71f6d80fbb5e251088defe305" +git-tree-sha1 = "04bdff0b09c65ff3e06a05e3eb7b120223da3d39" uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.3.18" +version = "1.4.0" [[deps.Serialization]] uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +[[deps.SetRounding]] +git-tree-sha1 = "d7a25e439d07a17b7cdf97eecee504c50fedf5f6" +uuid = "3cc68bcd-71a2-5612-b932-767ffbe40ab0" +version = "0.2.1" + [[deps.Setfield]] deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" version = "1.1.1" +[[deps.ShaderAbstractions]] +deps = ["ColorTypes", "FixedPointNumbers", "GeometryBasics", "LinearAlgebra", "Observables", "StaticArrays", "StructArrays", "Tables"] +git-tree-sha1 = "0d15c3e7b2003f4451714f08ffec2b77badc2dc4" +uuid = "65257c39-d410-5151-9873-9b3e5be5013e" +version = "0.3.0" + [[deps.SharedArrays]] deps = ["Distributed", "Mmap", "Random", "Serialization"] uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" @@ -1651,6 +1654,30 @@ git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" version = "1.1.0" +[[deps.SimpleGraphs]] +deps = ["AbstractLattices", "Combinatorics", "DataStructures", "IterTools", "LightXML", "LinearAlgebra", "LinearAlgebraX", "Optim", "Primes", "Random", "RingLists", "SimplePartitions", "SimplePolynomials", "SimpleRandom", "SparseArrays", "Statistics"] +git-tree-sha1 = "b608903049d11cc557c45e03b3a53e9260579c19" +uuid = "55797a34-41de-5266-9ec1-32ac4eb504d3" +version = "0.8.4" + +[[deps.SimplePartitions]] +deps = ["AbstractLattices", "DataStructures", "Permutations"] +git-tree-sha1 = "dcc02923a53f316ab97da8ef3136e80b4543dbf1" +uuid = "ec83eff0-a5b5-5643-ae32-5cbf6eedec9d" +version = "0.3.0" + +[[deps.SimplePolynomials]] +deps = ["Mods", "Multisets", "Polynomials", "Primes"] +git-tree-sha1 = "9f1b1f47279018b35316c62e829af1f3f6725a47" +uuid = "cc47b68c-3164-5771-a705-2bc0097375a0" +version = "0.2.13" + +[[deps.SimpleRandom]] +deps = ["Distributions", "LinearAlgebra", "Random"] +git-tree-sha1 = "3a6fb395e37afab81aeea85bae48a4db5cd7244a" +uuid = "a6525b86-64cd-54fa-8f65-62fc48bdc0e8" +version = "0.3.1" + [[deps.SimpleTraits]] deps = ["InteractiveUtils", "MacroTools"] git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" @@ -1659,24 +1686,18 @@ version = "0.9.4" [[deps.Sixel]] deps = ["Dates", "FileIO", "ImageCore", "IndirectArrays", "OffsetArrays", "REPL", "libsixel_jll"] -git-tree-sha1 = "8fb59825be681d451c246a795117f317ecbcaa28" +git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" -version = "0.1.2" - -[[deps.SnoopPrecompile]] -deps = ["Preferences"] -git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c" -uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" -version = "1.0.3" +version = "0.1.3" [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" [[deps.SortingAlgorithms]] deps = ["DataStructures"] -git-tree-sha1 = "a4ada03f999bd01b3a25dcaa30b2d929fe537e00" +git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.0" +version = "1.1.1" [[deps.SparseArrays]] deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] @@ -1684,9 +1705,9 @@ uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [[deps.SpecialFunctions]] deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "ef28127915f4229c971eb43f3fc075dd3fe91880" +git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3" uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.2.0" +version = "2.3.0" weakdeps = ["ChainRulesCore"] [deps.SpecialFunctions.extensions] @@ -1710,33 +1731,20 @@ git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" version = "0.1.1" -[[deps.Static]] -deps = ["IfElse"] -git-tree-sha1 = "08be5ee09a7632c32695d954a602df96a877bf0d" -uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.8.6" - -[[deps.StaticArrayInterface]] -deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "Static", "SuiteSparse"] -git-tree-sha1 = "33040351d2403b84afce74dae2e22d3f5b18edcb" -uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" -version = "1.4.0" -weakdeps = ["OffsetArrays", "StaticArrays"] - - [deps.StaticArrayInterface.extensions] - StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" - StaticArrayInterfaceStaticArraysExt = "StaticArrays" - [[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore", "Statistics"] -git-tree-sha1 = "c262c8e978048c2b095be1672c9bee55b4619521" +deps = ["LinearAlgebra", "Random", "StaticArraysCore"] +git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.5.24" +version = "1.6.2" +weakdeps = ["Statistics"] + + [deps.StaticArrays.extensions] + StaticArraysStatisticsExt = "Statistics" [[deps.StaticArraysCore]] -git-tree-sha1 = "6b7ba252635a5eff6a0b0664a41ee140a1c9e72a" +git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.0" +version = "1.4.2" [[deps.StatisticalTraits]] deps = ["ScientificTypesBase"] @@ -1757,9 +1765,9 @@ version = "1.6.0" [[deps.StatsBase]] deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "d1bf48bfcc554a3761a133fe3a9bb01488e06916" +git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.33.21" +version = "0.34.0" [[deps.StatsFuns]] deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] @@ -1781,12 +1789,6 @@ git-tree-sha1 = "8cc7a5385ecaa420f0b3426f9b0135d0df0638ed" uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d" version = "0.7.2" -[[deps.StringEncodings]] -deps = ["Libiconv_jll"] -git-tree-sha1 = "33c0da881af3248dafefb939a21694b97cfece76" -uuid = "69024149-9ee7-55f6-a4c4-859efe599b68" -version = "0.3.6" - [[deps.StringManipulation]] git-tree-sha1 = "46da2434b41f41ac3594ee9816ce5541c6096123" uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" @@ -1807,25 +1809,11 @@ deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" version = "5.10.1+6" -[[deps.TMLE]] -deps = ["AbstractDifferentiation", "CategoricalArrays", "Configurations", "Distributions", "GLM", "HypothesisTests", "LogExpFunctions", "MLJBase", "MLJGLMInterface", "MLJModels", "Missings", "PrettyTables", "Statistics", "TableOperations", "Tables", "YAML", "Zygote"] -git-tree-sha1 = "93a4acacae3a1dc0a598703e3c7ed46766614da0" -repo-rev = "main" -repo-url = ".." -uuid = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" -version = "0.11.1" - [[deps.TOML]] deps = ["Dates"] uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" version = "1.0.3" -[[deps.TableOperations]] -deps = ["SentinelArrays", "Tables", "Test"] -git-tree-sha1 = "e383c87cf2a1dc41fa30c093b2a19877c83e1bc1" -uuid = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" -version = "1.2.0" - [[deps.TableTraits]] deps = ["IteratorInterfaceExtensions"] git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" @@ -1853,12 +1841,6 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[deps.ThreadingUtilities]] -deps = ["ManualMemory"] -git-tree-sha1 = "c97f60dd4f2331e1a495527f80d242501d2f9865" -uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" -version = "0.5.1" - [[deps.TiffImages]] deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] git-tree-sha1 = "8621f5c499a8aa4aa970b1ae381aae0ef1576966" @@ -1888,9 +1870,9 @@ uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" version = "1.3.0" [[deps.URIs]] -git-tree-sha1 = "074f993b0ca030848b897beff716d93aca60f06a" +git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0" uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.4.2" +version = "1.5.0" [[deps.UUIDs]] deps = ["Random", "SHA"] @@ -1917,15 +1899,9 @@ version = "0.2.1" [[deps.UnsafeAtomicsLLVM]] deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "ea37e6066bf194ab78f4e747f5245261f17a7175" +git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.2" - -[[deps.VectorizationBase]] -deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static", "StaticArrayInterface"] -git-tree-sha1 = "b182207d4af54ac64cbc71797765068fdeff475d" -uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" -version = "0.21.64" +version = "0.1.3" [[deps.WeakRefStrings]] deps = ["DataAPI", "InlineStrings", "Parsers"] @@ -1939,6 +1915,11 @@ git-tree-sha1 = "de67fa59e33ad156a590055375a30b23c40299d3" uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" version = "0.5.5" +[[deps.WorkerUtilities]] +git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" +uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" +version = "1.6.1" + [[deps.XML2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "Zlib_jll"] git-tree-sha1 = "93c41695bc1c08c46c5899f4fe06d6ead504bb73" @@ -1952,22 +1933,22 @@ uuid = "aed1982a-8fda-507f-9586-7b0439959a61" version = "1.1.34+0" [[deps.Xorg_libX11_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] -git-tree-sha1 = "5be649d550f3f4b95308bf0183b82e2582876527" +deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] +git-tree-sha1 = "afead5aba5aa507ad5a3bf01f58f82c8d1403495" uuid = "4f6342f7-b3d2-589e-9d20-edeb45f2b2bc" -version = "1.6.9+4" +version = "1.8.6+0" [[deps.Xorg_libXau_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "4e490d5c960c314f33885790ed410ff3a94ce67e" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "6035850dcc70518ca32f012e46015b9beeda49d8" uuid = "0c0b7dd1-d40b-584c-a123-a41640f87eec" -version = "1.0.9+4" +version = "1.0.11+0" [[deps.Xorg_libXdmcp_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "4fe47bd2247248125c428978740e18a681372dd4" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "34d526d318358a859d7de23da945578e8e8727b7" uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" -version = "1.1.3+4" +version = "1.1.4+0" [[deps.Xorg_libXext_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] @@ -1982,56 +1963,28 @@ uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" version = "0.9.10+4" [[deps.Xorg_libpthread_stubs_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6783737e45d3c59a4a4c4091f5f88cdcf0908cbb" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "8fdda4c692503d44d04a0603d9ac0982054635f9" uuid = "14d82f49-176c-5ed1-bb49-ad3f5cbd8c74" -version = "0.1.0+3" +version = "0.1.1+0" [[deps.Xorg_libxcb_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] -git-tree-sha1 = "daf17f441228e7a3833846cd048892861cff16d6" +deps = ["Artifacts", "JLLWrappers", "Libdl", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] +git-tree-sha1 = "b4bfde5d5b652e22b9c790ad00af08b6d042b97d" uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" -version = "1.13.0+3" +version = "1.15.0+0" [[deps.Xorg_xtrans_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "79c31e7844f6ecf779705fbc12146eb190b7d845" +deps = ["Artifacts", "JLLWrappers", "Libdl"] +git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77" uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" -version = "1.4.0+3" - -[[deps.YAML]] -deps = ["Base64", "Dates", "Printf", "StringEncodings"] -git-tree-sha1 = "dbc7f1c0012a69486af79c8bcdb31be820670ba2" -uuid = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" -version = "0.4.8" +version = "1.5.0+0" [[deps.Zlib_jll]] deps = ["Libdl"] uuid = "83775a58-1f1d-513f-b197-d71354ab007a" version = "1.2.13+0" -[[deps.Zygote]] -deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "GPUArrays", "GPUArraysCore", "IRTools", "InteractiveUtils", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NaNMath", "Random", "Requires", "SnoopPrecompile", "SparseArrays", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "987ae5554ca90e837594a0f30325eeb5e7303d1e" -uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.60" - - [deps.Zygote.extensions] - ZygoteColorsExt = "Colors" - ZygoteDistancesExt = "Distances" - ZygoteTrackerExt = "Tracker" - - [deps.Zygote.weakdeps] - Colors = "5ae59095-9a9b-59fe-a467-6f913c188581" - Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.ZygoteRules]] -deps = ["ChainRulesCore", "MacroTools"] -git-tree-sha1 = "977aed5d006b840e2e40c0b48984f7463109046d" -uuid = "700de1a5-db45-46bc-99cf-38207098b444" -version = "0.2.3" - [[deps.isoband_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "51b5eeb3f98367157a7a12a1fb0aa5328946c03c" @@ -2053,7 +2006,7 @@ version = "0.15.1+0" [[deps.libblastrampoline_jll]] deps = ["Artifacts", "Libdl"] uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.7.0+0" +version = "5.8.0+0" [[deps.libfdk_aac_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] diff --git a/docs/Project.toml b/docs/Project.toml index 4ddd1047..18f38b85 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,4 +1,5 @@ [deps] +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" @@ -15,7 +16,6 @@ NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -TMLE = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" @@ -24,10 +24,11 @@ CairoMakie = "0.10.4" DataFrames = "1.5" Distributions = "0.25" Documenter = "0.27" -EvoTrees = "0.14" +EvoTrees = "0.15" Literate = "2.13" MLJ = "0.19.1" MLJLinearModels = "0.9.1" NearestNeighborModels = "0.2" SQLite = "1.4" +CSV = "0.10.11" StableRNGs = "1.0" diff --git a/examples/double_robustness.jl b/examples/double_robustness.jl index d2a44323..8b166cc1 100644 --- a/examples/double_robustness.jl +++ b/examples/double_robustness.jl @@ -25,7 +25,6 @@ using LogExpFunctions using MLJGLMInterface using DataFrames using CairoMakie -using CategoricalArrays μY(T, W) = exp.(1 .- 10T .+ 1W) diff --git a/examples/super_learning.jl b/examples/super_learning.jl index 31cb2618..75d410ff 100644 --- a/examples/super_learning.jl +++ b/examples/super_learning.jl @@ -3,7 +3,7 @@ ## What this tutorial is about -Super Learning, also known as Stacking, is an emsemble technique that was first introduced by Wolpert in 1992. +Super Learning, also known as Stacking, is an ensemble technique that was first introduced by Wolpert in 1992. Instead of selecting a model based on cross-validation performance, models are combined by a meta-learner to minimize the cross-validation error. It has also been shown by van der Laan et al. that the resulting Super Learner will perform at least as well as its best performing submodel. @@ -11,91 +11,46 @@ Super Learner will perform at least as well as its best performing submodel. Why is it important for Targeted Learning? The short answer is that the consistency (convergence in probability) of the targeted -estimator depends on the consistency of at least one of the nuisance estimands: ``Q_0`` or ``G_0`` (see [Mathematical setting](@ref) ). +estimator depends on the consistency of at least one of the nuisance estimands: ``Q_0`` or ``G_0``. By only using unrealistic models like linear models, we have little chance of satisfying the above criterion. Super Learning is a data driven way to leverage a diverse set of models and build the best performing estimator for both ``Q_0`` or ``G_0``. -In the following, we investigate the benefits of Super Learning for the estimation of [The Average Treatment Effect](@ref). +In the following, we investigate the benefits of Super Learning for the estimation of the Average Treatment Effect. ## The dataset -To demonstrate the benefits of Super Learning, we need to simulate a dataset more stimulating -than a simple linear regression, otherwise there would be nothing to do. Here is what I could -come up with: +For this example we will use the following perinatal dataset. The (haz01, parity01) are converted to +categorical values. =# - +using CSV using DataFrames using TMLE -using Random -using Distributions -using StableRNGs using CairoMakie -using CategoricalArrays using MLJ -using TMLE -using LogExpFunctions - -μY(T, W₁, W₂) = sin.(10T.*W₁).*exp.(-1 .+ 2W₁.*W₂ .- T.*W₂) .- cos.(10T.*W₂).*log.(2 .- W₁.*W₂) -μT(W₁, W₂) = logistic.(10sin.(W₁) .- 1.5W₂) - -function hard_problem(;n=1000, doT=nothing) - rng = StableRNG(123) - W₁ = rand(rng, Uniform(), n) - W₂ = rand(rng, Uniform(), n) - T = doT === nothing ? - rand(rng, Uniform(), n) .< μT(W₁, W₂) : - repeat([doT], n) - Y = μY(T, W₁, W₂) .+ rand(rng, Normal(0, 0.1), n) - return DataFrame(W₁=W₁, W₂=W₂, T=categorical(T), Y=Y) -end -N = 10000 -dataset = hard_problem(;n=N) +dataset = CSV.read( + joinpath(pkgdir(TMLE), "test", "data", "perinatal.csv"), + DataFrame, + select=[:haz01, :parity01, :apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn], + types=Float64 +) +dataset.haz01 = categorical(dataset.haz01) +dataset.parity01 = categorical(dataset.parity01) nothing # hide #= -It may seem difficult to understand it but we can still have a look at the function ``\mathbf{E}[Y|T,W₁, W₂]``. -=# - - -function plot_μY(;npoints=100) - W₁ = LinRange(0, 1, npoints) - W₂ = LinRange(0, 1, npoints) - y₁ = [μY(1, w₁, w₂) for w₁ in W₁, w₂ in W₂] - y₀ = [μY(0, w₁, w₂) for w₁ in W₁, w₂ in W₂] - t = [μT(w₁, w₂) for w₁ in W₁, w₂ in W₂] - fig = Figure(backgroundcolor = RGBf(0.98, 0.98, 0.98), resolution = (1000, 700)) - ax₁ = Axis3(fig[1, 1], aspect=:data, xlabel="W₁", ylabel="W₂", zlabel="μY") - Label(fig[1, 1, Top()], "μY(W₁, W₂, T=1)", fontsize = 30, tellwidth=false, tellheight=false) - surface!(ax₁, W₁, W₂, y₁) - ax₀ = Axis3(fig[1, 2], aspect=:data, xlabel="W₁", ylabel="W₂", zlabel="μY") - Label(fig[1, 2, Top()], "μY(W₁, W₂, T=0)", fontsize = 30, tellwidth=false, tellheight=false) - surface!(ax₀, W₁, W₂, y₀) - axₜ = Axis3(fig[2, :], aspect=:data, xlabel="W₁", ylabel="W₂", zlabel="μT") - Label(fig[2, :, Top()], "μT(W₁, W₂)", fontsize = 30, tellwidth=false, tellheight=false) - surface!(axₜ, W₁, W₂, t) - return fig -end -plot_μY(;npoints=100) +We will also assume the following causal model: -#= -Also, even though the analytical solution for the ATE may be intractable, since we known the -data generating process, we can approximate it by sampling. =# -function approximate_ate(;n=1000) - dataset_1 = hard_problem(;n=n, doT = 1) - dataset_0 = hard_problem(;n=n, doT = 0) - return mean(dataset_1.Y) - mean(dataset_0.Y) -end - -ψ₀ = approximate_ate(;n=1000) +scm = SCM( + SE(:haz01, [:parity01, :apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn]), + SE(:parity01, [:apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn]) +) #= -Now that the data has been generated and we know the solution to our problem, we can dive into the estimation -part. ## Defining a Super Learner in MLJ @@ -108,10 +63,10 @@ function. The three most important type of arguments for a Stack are: One important point is that MLJ does not provide any model by itself, those have to be loaded from external compatible libraries. You can search for available models that match your data. -for instance, for ``G_0 = P(T| W_1, W_2)``, we need classification models and we can see there are quire a bunch of them: +In our case, for both ``G_0`` and ``Q_0`` we need classification models and we can see there are quire a few of them: =# -G_available_models = models(matching(dataset[!, [:W₁, :W₂]], dataset.T)) +G_available_models = models(matching(dataset[!, parents(scm.parity01)], dataset.parity01)) #= !!! note "Stack limitations" @@ -127,7 +82,7 @@ using MLJLinearModels using MLJModels using NearestNeighborModels -function Gmodels() +function superlearner_models() lambdas = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 0, 1., 10., 100.] logistic_models = [LogisticClassifier(lambda=l) for l in lambdas] logistic_models = NamedTuple{Tuple(Symbol("lr_$i") for i in eachindex(lambdas))}(logistic_models) @@ -139,72 +94,45 @@ function Gmodels() return merge(logistic_models, evo_trees, knns) end -G_super_learner = Stack(; +superlearner = Stack(; metalearner = LogisticClassifier(lambda=0), resampling = StratifiedCV(nfolds=3), measure = log_loss, - Gmodels()... + superlearner_models()... ) - nothing # hide #= -And now do the same for ``Q_0`` -=# -function Qmodels() - lambdas = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 0, 1., 10., 100.] - linear_models = [RidgeRegressor(lambda=l) for l in lambdas] - linear_models = NamedTuple{Tuple(Symbol("lr_$i") for i in eachindex(lambdas))}(linear_models) - evo_trees = [EvoTreeRegressor(lambda=l) for l in lambdas] - evo_trees = NamedTuple{Tuple(Symbol("tree_$i") for i in eachindex(lambdas))}(evo_trees) - Ks = [5, 10, 50, 100] - knns = [KNNRegressor(K=k) for k in Ks] - knns = NamedTuple{Tuple(Symbol("knn_$i") for i in eachindex(Ks))}(knns) - return merge(linear_models, evo_trees, knns) -end +and assign those models to the `SCM`: -Q_super_learner = Stack(; - metalearner = LinearRegressor(fit_intercept=false), - resampling=CV(nfolds=3), - Qmodels()... - ) +=# -nothing # hide +setmodel!(scm.haz01, with_encoder(superlearner)) +setmodel!(scm.parity01, superlearner) #= + ## Targeted estimation -Let us move to the targeted estimation step itself. We define the target estimand (the ATE) and the nuisance estimands specification: +Let us move to the targeted estimation step itself. We define the target estimand (the ATE): =# Ψ = ATE( - outcome=:Y, - treatment=(T=(case=true, control=false),), - confounders=[:W₁, :W₂], - outcome_model=Q_super_learner, - treatment_model=G_super_learner + scm, + outcome=:haz01, + treatment=(parity01=(case=true, control=false),), ) nothing # hide #= -Finally run the TMLE procedure and check the result +Finally run the TMLE procedure: =# tmle_result, _ = tmle!(Ψ, dataset) -test_result = OneSampleTTest(tmle(tmle_result), ψ₀) - -#= -Now, what if we had used linear models only instead of the Super Learner? This is easy to check -=# -setmodel!(Ψ.scm.Y, LinearRegressor()) -setmodel!(Ψ.scm.T, LogisticClassifier()) - -tmle_result_linear, cache = tmle!(Ψ, dataset) - -test_result_linear = OneSampleTTest(tmle(tmle_result_linear), ψ₀) +tmle_result # using Test # hide diff --git a/ext/GraphsTMLEExt.jl b/ext/GraphsTMLEExt.jl deleted file mode 100644 index c3e68b73..00000000 --- a/ext/GraphsTMLEExt.jl +++ /dev/null @@ -1,48 +0,0 @@ -module GraphsTMLEExt - -using TMLE -using Graphs -using CairoMakie -using GraphMakie - -function edges_list_and_nodes_mapping(scm::SCM) - edge_list = [] - node_ids = Dict() - node_id = 1 - for (outcome, eq) in equations(scm) - if outcome ∉ keys(node_ids) - push!(node_ids, outcome => node_id) - node_id += 1 - end - outcome_id = node_ids[outcome] - for parent in TMLE.parents(eq) - if parent ∉ keys(node_ids) - push!(node_ids, parent => node_id) - node_id += 1 - end - parent_id = node_ids[parent] - push!(edge_list, Edge(parent_id, outcome_id)) - end - end - return edge_list, node_ids -end - -function DAG(scm::SCM) - edges_list, nodes_mapping = edges_list_and_nodes_mapping(scm) - return SimpleDiGraphFromIterator(edges_list), nodes_mapping -end - -function graphplot(scm::SCM; kwargs...) - graph, nodes_mapping = DAG(scm) - f, ax, p = graphplot(graph; - ilabels = [x[1] for x in sort(collect(nodes_mapping), by=x -> x[2])], - kwargs...) - Legend( - f[1, 2], - [MarkerElement(color=:black, marker='⋆', markersize = 10) for eq in equations(scm)], - [TMLE.string_repr(eq) for (_, eq) in equations(scm)]) - hidedecorations!(ax); - hidespines!(ax); - ax.aspect = DataAspect() - return fig, ax, p -end \ No newline at end of file diff --git a/src/TMLE.jl b/src/TMLE.jl index dd566ffa..f877a479 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -31,8 +31,8 @@ export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS export fit!, optimize_ordering, optimize_ordering! -export tmle!, tmle, ose, naive -export var, estimate, initial_estimate, OneSampleTTest, OneSampleZTest, pvalue, confint +export tmle!, tmle, ose, initial +export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer, with_encoder export BackdoorAdjustment @@ -74,62 +74,57 @@ function run_precompile_workload() dataset = (C₁=C₁, W₁=W₁, W₂=W₂, T₁=T₁, T₂=T₂, Y₁=Y₁, Y₂=Y₂) @compile_workload begin - # all calls in this block will be precompiled, regardless of whether - # they belong to your package or not (on Julia 1.8 and higher) - Ψ = CM( + # SCM constructors + ## Incremental + scm = SCM() + push!(scm, SE(:Y₁, [:T₁, :W₁, :W₂], model=LinearRegressor())) + push!(scm, SE(:T₁, [:W₁, :W₂])) + setmodel!(scm.T₁, LinearBinaryClassifier()) + ## Implicit through estimand + for estimand_type in [CM, ATE, IATE] + estimand_type(outcome=:Y₁, treatment=(T₁=true,), confounders=[:W₁, :W₂]) + end + ## Complete + scm = SCM( + SE(:Y₁, [:T₁, :W₁, :W₂], model=with_encoder(LinearRegressor())), + SE(:T₁, [:W₁, :W₂],model=LinearBinaryClassifier()), + SE(:Y₂, [:T₁, :T₂, :W₁, :W₂, :C₁], model=with_encoder(LinearBinaryClassifier())), + SE(:T₂, [:W₁, :W₂],model=LinearBinaryClassifier()), + ) + + # Estimate some parameters + Ψ₁ = CM( + scm, outcome =:Y₁, treatment=(T₁=true,), - confounders=[:W₁, :W₂], - covariates = [:C₁] - ) - η_spec = NuisanceSpec( - LinearRegressor(), - LinearBinaryClassifier() ) - r, cache = tmle(Ψ, η_spec, dataset, verbosity=0) - continuous_estimands = [ - ATE( - outcome =:Y₁, - treatment=(T₁=(case=true, control=false),), - confounders=[:W₁, :W₂], - covariates = [:C₁] - ), - # IATE( - # outcome =:Y₁, - # treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - # confounders=[:W₁, :W₂], - # covariates = [:C₁] - # ) - ] - cache = TMLECache(dataset) - for Ψ in continuous_estimands - tmle!(cache, Ψ, η_spec, verbosity=0) - end + result₁, fluctuation = tmle!(Ψ₁, dataset) - # Precompiling with a binary outcome - binary_estimands = [ - ATE( - outcome =:Y₂, - treatment=(T₁=(case=true, control=false),), - confounders=[:W₁, :W₂], - covariates = [:C₁] - ), - # IATE( - # outcome =:Y₂, - # treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - # confounders=[:W₁, :W₂], - # covariates = [:C₁] - # ) - ] - η_spec = NuisanceSpec( - LinearBinaryClassifier(), - LinearBinaryClassifier() + Ψ₂ = ATE( + scm, + outcome=:Y₂, + treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)) ) + result₂, fluctuation = tmle!(Ψ₂, dataset) - for Ψ in binary_estimands - tmle!(cache, Ψ, η_spec, verbosity=0) - end - + Ψ₃ = IATE( + scm, + outcome=:Y₂, + treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)) + ) + result₃, fluctuation = tmle!(Ψ₃, dataset) + + # Composition + composed_result = compose((x,y) -> x - y, tmle(result₂), tmle(result₁)) + + # Results manipulation + initial(result) + OneSampleTTest(tmle(result₃)) + OneSampleZTest(tmle(result₃)) + OneSampleTTest(ose(result₃)) + OneSampleZTest(ose(result₃)) + OneSampleZTest(composed_result) + end end diff --git a/src/estimate.jl b/src/estimate.jl index da7f767c..9cec1caf 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -92,10 +92,20 @@ HypothesisTests.OneSampleTTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) = OneSa Performs a T test on the ComposedTMLEstimate. """ function HypothesisTests.OneSampleTTest(r::ComposedTMLEstimate, Ψ₀=0) - @assert length(r.Ψ̂) > 1 "OneSampleTTest is only implemeted for real-valued statistics." + @assert length(r.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." return OneSampleTTest(estimate(r), sqrt(var(r)), 1, Ψ₀) end +""" + OneSampleZTest(r::ComposedTMLEstimate, Ψ₀=0) + +Performs a T test on the ComposedTMLEstimate. +""" +function HypothesisTests.OneSampleZTest(r::ComposedTMLEstimate, Ψ₀=0) + @assert length(r.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleZTest(estimate(r), sqrt(var(r)), 1, Ψ₀) +end + """ compose(f, estimation_results::Vararg{AsymptoticallyLinearEstimate, N}) where N diff --git a/test/graphs_extension.jl b/test/graphs_extension.jl deleted file mode 100644 index 1324d15a..00000000 --- a/test/graphs_extension.jl +++ /dev/null @@ -1,20 +0,0 @@ -module TestGraphsTMLEExt - -using Test -using TMLE -using Graphs -using CairoMakie -using GraphMakie - -@testset "Test DAG and graphplot" begin - scm = SCM( - SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C]), - SE(:T₁, [:W₁₁, :W₁₂]), - SE(:T₂, [:W₂₁, :W₂₂]), - ) - dag, nodes_mapping = TMLE.DAG(scm) -end - -end - -true \ No newline at end of file diff --git a/test/non_regression_test.jl b/test/non_regression_test.jl index 49a2fe8e..f2160f4a 100644 --- a/test/non_regression_test.jl +++ b/test/non_regression_test.jl @@ -24,7 +24,7 @@ using MLJGLMInterface Ψ = ATE(scm, outcome=:haz01, treatment=(parity01=(case=1, control=0),)) result, fluctuation_mach = tmle!(Ψ, dataset, threshold=0.025, verbosity=0) - l, u = confint(OneSampleTTest(result.tmle)) + l, u = confint(OneSampleTTest(tmle(result))) @test estimate(result.tmle) ≈ -0.185533 atol = 1e-6 @test l ≈ -0.279246 atol = 1e-6 @test u ≈ -0.091821 atol = 1e-6 From e4138e5e2f45d007ed48231da4610ff789e43435 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 11 Aug 2023 14:13:57 +0100 Subject: [PATCH 039/151] update docs and tests --- docs/make.jl | 6 +++--- docs/src/user_guide/estimation.md | 1 + test/utils.jl | 24 ++++-------------------- 3 files changed, 8 insertions(+), 23 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index ae3c7c60..af708a93 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -31,9 +31,9 @@ makedocs(; "User Guide" => [joinpath("user_guide", f) for f in ("scm.md", "estimands.md", "estimation.md", "adjustment.md", "misc.md")], "Examples" => [ - #joinpath("examples", "introduction_to_targeted_learning.md"), - # joinpath("examples", "super_learning.md"), - # joinpath("examples", "double_robustness.md") + # joinpath("examples", "introduction_to_targeted_learning.md"), + joinpath("examples", "super_learning.md"), + joinpath("examples", "double_robustness.md") ], "Resources" => "resources.md", "API Reference" => "api.md" diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index 8761ae1d..c8805c0b 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -15,6 +15,7 @@ using CategoricalArrays using TMLE using LogExpFunctions using MLJLinearModels +using MLJ function make_dataset(;n=1000) rng = StableRNG(123) diff --git a/test/utils.jl b/test/utils.jl index 1379b6a4..22e46056 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -190,21 +190,13 @@ end @test w == bw == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] mach = TMLE.tmle_step(Ψ, weighted_fluctuation=weighted_fluctuation) - @test fitted_params(mach) == ( - features = [:covariate], - coef = [0.12500000000000003], - intercept = 0.0 - ) + @test fitted_params(mach).coef[1] isa Float64 cov, w = TMLE.clever_covariate_and_weights(scm, W, T, indicator_fns) @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] unweighted_mach = TMLE.tmle_step(Ψ, weighted_fluctuation=false) - @test fitted_params(unweighted_mach) == ( - features = [:covariate], - coef = [0.047619047619047616], - intercept = 0.0 - ) + @test fitted_params(mach).coef[1] isa Float64 end @testset "Test Targeting sequence with 2 Treatments" begin @@ -243,11 +235,7 @@ end @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] mach = TMLE.tmle_step(Ψ) - @test fitted_params(mach) == ( - features = [:covariate], - coef = [-0.0945122832139119], - intercept = 0.0 - ) + @test fitted_params(mach).coef[1] isa Float64 end @testset "Test Targeting sequence with 3 Treatments" begin @@ -291,11 +279,7 @@ end @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] mach = TMLE.tmle_step(Ψ) - @test fitted_params(mach) == ( - features = [:covariate], - coef = [-0.02665556018325698], - intercept = 0.0 - ) + @test fitted_params(mach).coef[1] isa Float64 end @testset "Test fluctuation_input" begin From 90e2d109e86b4b11943c76f6d03fa710d1a977c3 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 11 Aug 2023 14:59:50 +0100 Subject: [PATCH 040/151] up deps and fix examples --- docs/Manifest.toml | 129 ++++++++++++++++++++++++++++++++-- docs/Project.toml | 4 +- examples/double_robustness.jl | 2 +- examples/super_learning.jl | 9 +-- src/estimands.jl | 61 +++++++++------- test/Project.toml | 2 - 6 files changed, 162 insertions(+), 45 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index bdd9fff0..ab769bf3 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,8 +1,8 @@ # This file is machine-generated - editing it directly is not advised -julia_version = "1.9.1" +julia_version = "1.9.2" manifest_format = "2.0" -project_hash = "82d74b11546c28748c1342b9b9b370ce23d4dd37" +project_hash = "65868c33a6e55148769c00ab2e5862d8c6fae92d" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -78,6 +78,12 @@ version = "7.4.11" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +[[deps.ArrayInterfaceCore]] +deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] +git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d" +uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" +version = "0.1.29" + [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -124,6 +130,12 @@ git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" version = "0.1.7" +[[deps.BitTwiddlingConvenienceFunctions]] +deps = ["Static"] +git-tree-sha1 = "0c5f81f47bbbcf4aea7b2959135713459170798b" +uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" +version = "0.1.5" + [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" @@ -135,6 +147,12 @@ git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.2" +[[deps.CPUSummary]] +deps = ["CpuId", "IfElse", "PrecompileTools", "Static"] +git-tree-sha1 = "89e0654ed8c7aebad6d5ad235d6242c2d737a928" +uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" +version = "0.2.3" + [[deps.CRC32c]] uuid = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" @@ -240,6 +258,12 @@ git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" version = "1.16.0" +[[deps.CloseOpenIntervals]] +deps = ["Static", "StaticArrayInterface"] +git-tree-sha1 = "70232f82ffaab9dc52585e0dd043b5e0c6b714f1" +uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" +version = "0.1.12" + [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" @@ -300,7 +324,7 @@ weakdeps = ["Dates", "LinearAlgebra"] [[deps.CompilerSupportLibraries_jll]] deps = ["Artifacts", "Libdl"] uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.2+0" +version = "1.0.5+0" [[deps.ComputationalResources]] git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" @@ -329,6 +353,12 @@ git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.2" +[[deps.CpuId]] +deps = ["Markdown"] +git-tree-sha1 = "fcbb72b032692610bfbdb15018ac16a36cf2e406" +uuid = "adafc99b-e345-5852-983c-f28acb93d879" +version = "0.3.1" + [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" @@ -467,10 +497,10 @@ uuid = "90fa49ef-747e-5e6f-a989-263ba693cf1a" version = "0.5.2" [[deps.EvoTrees]] -deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "1b418518c0eb1fd1ef0a6d0bfc8051e6abb1232b" +deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "LoopVectorization", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "fb09b634ba4b1c98ca319a1705df340d4b2005f0" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.15.2" +version = "0.14.11" [[deps.ExactPredicates]] deps = ["IntervalArithmetic", "Random", "StaticArraysCore", "Test"] @@ -706,6 +736,12 @@ git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" version = "2.8.1+1" +[[deps.HostCPUFeatures]] +deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"] +git-tree-sha1 = "eb8fed28f4994600e29beef49744639d985a04b2" +uuid = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" +version = "0.1.16" + [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" @@ -718,6 +754,11 @@ git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" version = "0.2.3" +[[deps.IfElse]] +git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" +uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +version = "0.1.1" + [[deps.ImageAxes]] deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"] git-tree-sha1 = "2e4520d67b0cef90865b3ef727594d2a58e0e1f8" @@ -928,6 +969,12 @@ git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" version = "1.9.0" +[[deps.LayoutPointers]] +deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] +git-tree-sha1 = "88b8f66b604da079a627b6fb2860d3704a6729a1" +uuid = "10f19ff3-798f-405d-979b-55457f8fc047" +version = "0.1.14" + [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" @@ -1060,6 +1107,17 @@ git-tree-sha1 = "cedb76b37bc5a6c702ade66be44f831fa23c681e" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.0.0" +[[deps.LoopVectorization]] +deps = ["ArrayInterface", "ArrayInterfaceCore", "CPUSummary", "CloseOpenIntervals", "DocStringExtensions", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "PrecompileTools", "SIMDTypes", "SLEEFPirates", "Static", "StaticArrayInterface", "ThreadingUtilities", "UnPack", "VectorizationBase"] +git-tree-sha1 = "c88a4afe1703d731b1c4fdf4e3c7e77e3b176ea2" +uuid = "bdcacae8-1622-11e9-2a5c-532679323890" +version = "0.12.165" +weakdeps = ["ChainRulesCore", "ForwardDiff", "SpecialFunctions"] + + [deps.LoopVectorization.extensions] + ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"] + SpecialFunctionsExt = "SpecialFunctions" + [[deps.LossFunctions]] deps = ["Markdown", "Requires", "Statistics"] git-tree-sha1 = "c2b72b61d2e3489b1f9cae3403226b21ec90c943" @@ -1148,6 +1206,11 @@ git-tree-sha1 = "87a85ff81583bd392642869557cb633532989517" uuid = "20f20a25-4f0e-4fdf-b5d1-57303727442b" version = "0.6.4" +[[deps.ManualMemory]] +git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" +uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" +version = "0.1.8" + [[deps.MappedArrays]] git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" @@ -1390,7 +1453,7 @@ version = "0.42.2+0" [[deps.Pkg]] deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.0" +version = "1.9.2" [[deps.PkgVersion]] deps = ["Pkg"] @@ -1404,6 +1467,12 @@ git-tree-sha1 = "f92e1315dadf8c46561fb9396e525f7200cdc227" uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" version = "1.3.5" +[[deps.PolyesterWeave]] +deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] +git-tree-sha1 = "240d7170f5ffdb285f9427b92333c3463bf65bf6" +uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" +version = "0.2.1" + [[deps.PolygonOps]] git-tree-sha1 = "77b3d3605fc1cd0b42d95eba87dfcd2bf67d5ff6" uuid = "647866c9-e3ac-4575-94e7-e3d426903924" @@ -1573,6 +1642,17 @@ version = "0.2.1" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" +[[deps.SIMDTypes]] +git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" +uuid = "94e857df-77ce-4151-89e5-788b33177be4" +version = "0.1.0" + +[[deps.SLEEFPirates]] +deps = ["IfElse", "Static", "VectorizationBase"] +git-tree-sha1 = "4b8586aece42bee682399c4c4aee95446aa5cd19" +uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" +version = "0.6.39" + [[deps.SQLite]] deps = ["DBInterface", "Random", "SQLite_jll", "Serialization", "Tables", "WeakRefStrings"] git-tree-sha1 = "eb9a473c9b191ced349d04efa612ec9f39c087ea" @@ -1690,6 +1770,12 @@ git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" version = "0.1.3" +[[deps.SnoopPrecompile]] +deps = ["Preferences"] +git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c" +uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" +version = "1.0.3" + [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -1731,6 +1817,23 @@ git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" version = "0.1.1" +[[deps.Static]] +deps = ["IfElse"] +git-tree-sha1 = "f295e0a1da4ca425659c57441bcb59abb035a4bc" +uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" +version = "0.8.8" + +[[deps.StaticArrayInterface]] +deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "Static", "SuiteSparse"] +git-tree-sha1 = "33040351d2403b84afce74dae2e22d3f5b18edcb" +uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" +version = "1.4.0" +weakdeps = ["OffsetArrays", "StaticArrays"] + + [deps.StaticArrayInterface.extensions] + StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" + StaticArrayInterfaceStaticArraysExt = "StaticArrays" + [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore"] git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" @@ -1841,6 +1944,12 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +[[deps.ThreadingUtilities]] +deps = ["ManualMemory"] +git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" +uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" +version = "0.5.2" + [[deps.TiffImages]] deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] git-tree-sha1 = "8621f5c499a8aa4aa970b1ae381aae0ef1576966" @@ -1903,6 +2012,12 @@ git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" version = "0.1.3" +[[deps.VectorizationBase]] +deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static", "StaticArrayInterface"] +git-tree-sha1 = "b182207d4af54ac64cbc71797765068fdeff475d" +uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" +version = "0.21.64" + [[deps.WeakRefStrings]] deps = ["DataAPI", "InlineStrings", "Parsers"] git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" diff --git a/docs/Project.toml b/docs/Project.toml index 18f38b85..c5e5ccd0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -20,15 +20,15 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +CSV = "0.10.11" CairoMakie = "0.10.4" DataFrames = "1.5" Distributions = "0.25" Documenter = "0.27" -EvoTrees = "0.15" +EvoTrees = "0.14" Literate = "2.13" MLJ = "0.19.1" MLJLinearModels = "0.9.1" NearestNeighborModels = "0.2" SQLite = "1.4" -CSV = "0.10.11" StableRNGs = "1.0" diff --git a/examples/double_robustness.jl b/examples/double_robustness.jl index 8b166cc1..c0efdf7c 100644 --- a/examples/double_robustness.jl +++ b/examples/double_robustness.jl @@ -32,7 +32,7 @@ using CairoMakie function generate_data(;n = 1000, rng = StableRNG(123)) W = rand(rng, Normal(), n) μT = logistic.(0.3 .- 0.5W) - T = float(rand(rng, Uniform(), n) .< μT) + T = float(rand(rng, n) .< μT) ϵ = rand(rng, Normal(), n) Y = μY(T, W) .+ ϵ Y₁ = μY(ones(n), W) .+ ϵ diff --git a/examples/super_learning.jl b/examples/super_learning.jl index 75d410ff..a585e556 100644 --- a/examples/super_learning.jl +++ b/examples/super_learning.jl @@ -26,7 +26,6 @@ categorical values. using CSV using DataFrames using TMLE -using CairoMakie using MLJ dataset = CSV.read( @@ -35,8 +34,8 @@ dataset = CSV.read( select=[:haz01, :parity01, :apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn], types=Float64 ) -dataset.haz01 = categorical(dataset.haz01) -dataset.parity01 = categorical(dataset.parity01) +dataset.haz01 = categorical(Int.(dataset.haz01)) +dataset.parity01 = categorical(Int.(dataset.parity01)) nothing # hide #= @@ -134,9 +133,5 @@ tmle_result, _ = tmle!(Ψ, dataset) tmle_result -# -using Test # hide -pvalue(test_result) > 0.05 #hide -nothing # hide diff --git a/src/estimands.jl b/src/estimands.jl index 0f726232..0bb965ac 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -19,18 +19,21 @@ abstract type Estimand end ## Constructors -- CM(;scm::SCM, outcome, treatment) -- CM(scm::SCM; outcome, treatment) - -where: - -- scm: is a `StructuralCausalModel` (see [`SCM`](@ref)) -- outcome: is a `Symbol` -- treatment: is a `NamedTuple` - +- CM(scm::SCM; outcome::Symbol, treatment::NamedTuple) +- CM( + outcome::Symbol, + treatment::NamedTuple, + confounders::Union{Symbol, AbstractVector{Symbol}}, + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LinearBinaryClassifier() +) ## Example +```julia Ψ = CM(scm, outcome=:Y, treatment=(T=1,)) + +``` """ struct ConditionalMean <: Estimand scm::StructuralCausalModel @@ -56,18 +59,21 @@ const CM = ConditionalMean ## Constructors -- ATE(;scm::SCM, outcome, treatment) -- ATE(scm::SCM; outcome, treatment) - -where: - -- scm: is a `StructuralCausalModel` (see [`SCM`](@ref)) -- outcome: is a `Symbol` -- treatment: is a `NamedTuple` +- ATE(scm::SCM; outcome::Symbol, treatment::NamedTuple) +- ATE( + outcome::Symbol, + treatment::NamedTuple, + confounders::Union{Symbol, AbstractVector{Symbol}}, + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LinearBinaryClassifier() +) ## Example +```julia Ψ = ATE(scm, outcome=:Y, treatment=(T=(case=1,control=0),) +``` """ struct AverageTreatmentEffect <: Estimand scm::StructuralCausalModel @@ -90,24 +96,27 @@ const ATE = AverageTreatmentEffect ## Definition -For two treatments with settings (1, 0): +For two treatments with case/control settings (1, 0): ``IATE = E[Y|do(T₁=1, T₂=1)] - E[Y|do(T₁=1, T₂=0)] - E[Y|do(T₁=0, T₂=1)] + E[Y|do(T₁=0, T₂=0)]`` ## Constructors -- IATE(;scm::SCM, outcome, treatment) -- IATE(scm::SCM; outcome, treatment) - -where: - -- scm: is a `StructuralCausalModel` (see [`SCM`](@ref)) -- outcome: is a `Symbol` -- treatment: is a `NamedTuple` +- IATE(scm::SCM; outcome::Symbol, treatment::NamedTuple) +- IATE( + outcome::Symbol, + treatment::NamedTuple, + confounders::Union{Symbol, AbstractVector{Symbol}}, + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LinearBinaryClassifier() +) ## Example +```julia Ψ = IATE(scm, outcome=:Y, treatment=(T₁=(case=1,control=0), T₂=(case=1,control=0)) +``` """ struct InteractionAverageTreatmentEffect <: Estimand scm::StructuralCausalModel diff --git a/test/Project.toml b/test/Project.toml index 5633e74d..d9942ee4 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,7 +5,6 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" @@ -25,7 +24,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] CSV = "0.10" DataFrames = "1.5" -EvoTrees = "0.15" MLJLinearModels = "0.9" StableRNGs = "1.0" StatsBase = "0.33" From 5fa1be442946ea1c08312abff67915bdaa54ffec Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 13 Aug 2023 14:11:38 +0100 Subject: [PATCH 041/151] update evotree dep --- docs/Manifest.toml | 123 ++------------------------------------------- docs/Project.toml | 2 +- 2 files changed, 5 insertions(+), 120 deletions(-) diff --git a/docs/Manifest.toml b/docs/Manifest.toml index ab769bf3..39087c80 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.9.2" manifest_format = "2.0" -project_hash = "65868c33a6e55148769c00ab2e5862d8c6fae92d" +project_hash = "0251892739390e1f6b399b7b0a8ba56b014a5eee" [[deps.ANSIColoredPrinters]] git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" @@ -78,12 +78,6 @@ version = "7.4.11" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" -[[deps.ArrayInterfaceCore]] -deps = ["LinearAlgebra", "SnoopPrecompile", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "e5f08b5689b1aad068e01751889f2f615c7db36d" -uuid = "30b0a656-2188-435a-8636-2ec0e6a096e2" -version = "0.1.29" - [[deps.Artifacts]] uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" @@ -130,12 +124,6 @@ git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" version = "0.1.7" -[[deps.BitTwiddlingConvenienceFunctions]] -deps = ["Static"] -git-tree-sha1 = "0c5f81f47bbbcf4aea7b2959135713459170798b" -uuid = "62783981-4cbd-42fc-bca8-16325de8dc4b" -version = "0.1.5" - [[deps.Bzip2_jll]] deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" @@ -147,12 +135,6 @@ git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" version = "0.4.2" -[[deps.CPUSummary]] -deps = ["CpuId", "IfElse", "PrecompileTools", "Static"] -git-tree-sha1 = "89e0654ed8c7aebad6d5ad235d6242c2d737a928" -uuid = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9" -version = "0.2.3" - [[deps.CRC32c]] uuid = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" @@ -258,12 +240,6 @@ git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" version = "1.16.0" -[[deps.CloseOpenIntervals]] -deps = ["Static", "StaticArrayInterface"] -git-tree-sha1 = "70232f82ffaab9dc52585e0dd043b5e0c6b714f1" -uuid = "fb6a15b2-703c-40df-9091-08a04967cfa9" -version = "0.1.12" - [[deps.CodecZlib]] deps = ["TranscodingStreams", "Zlib_jll"] git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" @@ -353,12 +329,6 @@ git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" version = "0.6.2" -[[deps.CpuId]] -deps = ["Markdown"] -git-tree-sha1 = "fcbb72b032692610bfbdb15018ac16a36cf2e406" -uuid = "adafc99b-e345-5852-983c-f28acb93d879" -version = "0.3.1" - [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" @@ -497,10 +467,10 @@ uuid = "90fa49ef-747e-5e6f-a989-263ba693cf1a" version = "0.5.2" [[deps.EvoTrees]] -deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "LoopVectorization", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "fb09b634ba4b1c98ca319a1705df340d4b2005f0" +deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] +git-tree-sha1 = "5023442c1f797c0fd6677b1a1886ab44f43f3378" uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.14.11" +version = "0.16.0" [[deps.ExactPredicates]] deps = ["IntervalArithmetic", "Random", "StaticArraysCore", "Test"] @@ -736,12 +706,6 @@ git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" version = "2.8.1+1" -[[deps.HostCPUFeatures]] -deps = ["BitTwiddlingConvenienceFunctions", "IfElse", "Libdl", "Static"] -git-tree-sha1 = "eb8fed28f4994600e29beef49744639d985a04b2" -uuid = "3e5b6fbb-0976-4d2c-9146-d79de83f2fb0" -version = "0.1.16" - [[deps.HypergeometricFunctions]] deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" @@ -754,11 +718,6 @@ git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" version = "0.2.3" -[[deps.IfElse]] -git-tree-sha1 = "debdd00ffef04665ccbb3e150747a77560e8fad1" -uuid = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" -version = "0.1.1" - [[deps.ImageAxes]] deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"] git-tree-sha1 = "2e4520d67b0cef90865b3ef727594d2a58e0e1f8" @@ -969,12 +928,6 @@ git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" version = "1.9.0" -[[deps.LayoutPointers]] -deps = ["ArrayInterface", "LinearAlgebra", "ManualMemory", "SIMDTypes", "Static", "StaticArrayInterface"] -git-tree-sha1 = "88b8f66b604da079a627b6fb2860d3704a6729a1" -uuid = "10f19ff3-798f-405d-979b-55457f8fc047" -version = "0.1.14" - [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" @@ -1107,17 +1060,6 @@ git-tree-sha1 = "cedb76b37bc5a6c702ade66be44f831fa23c681e" uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" version = "1.0.0" -[[deps.LoopVectorization]] -deps = ["ArrayInterface", "ArrayInterfaceCore", "CPUSummary", "CloseOpenIntervals", "DocStringExtensions", "HostCPUFeatures", "IfElse", "LayoutPointers", "LinearAlgebra", "OffsetArrays", "PolyesterWeave", "PrecompileTools", "SIMDTypes", "SLEEFPirates", "Static", "StaticArrayInterface", "ThreadingUtilities", "UnPack", "VectorizationBase"] -git-tree-sha1 = "c88a4afe1703d731b1c4fdf4e3c7e77e3b176ea2" -uuid = "bdcacae8-1622-11e9-2a5c-532679323890" -version = "0.12.165" -weakdeps = ["ChainRulesCore", "ForwardDiff", "SpecialFunctions"] - - [deps.LoopVectorization.extensions] - ForwardDiffExt = ["ChainRulesCore", "ForwardDiff"] - SpecialFunctionsExt = "SpecialFunctions" - [[deps.LossFunctions]] deps = ["Markdown", "Requires", "Statistics"] git-tree-sha1 = "c2b72b61d2e3489b1f9cae3403226b21ec90c943" @@ -1206,11 +1148,6 @@ git-tree-sha1 = "87a85ff81583bd392642869557cb633532989517" uuid = "20f20a25-4f0e-4fdf-b5d1-57303727442b" version = "0.6.4" -[[deps.ManualMemory]] -git-tree-sha1 = "bcaef4fc7a0cfe2cba636d84cda54b5e4e4ca3cd" -uuid = "d125e4d3-2237-4719-b19c-fa641b8a4667" -version = "0.1.8" - [[deps.MappedArrays]] git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" @@ -1467,12 +1404,6 @@ git-tree-sha1 = "f92e1315dadf8c46561fb9396e525f7200cdc227" uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" version = "1.3.5" -[[deps.PolyesterWeave]] -deps = ["BitTwiddlingConvenienceFunctions", "CPUSummary", "IfElse", "Static", "ThreadingUtilities"] -git-tree-sha1 = "240d7170f5ffdb285f9427b92333c3463bf65bf6" -uuid = "1d0040c9-8b98-4ee7-8388-3f51789ca0ad" -version = "0.2.1" - [[deps.PolygonOps]] git-tree-sha1 = "77b3d3605fc1cd0b42d95eba87dfcd2bf67d5ff6" uuid = "647866c9-e3ac-4575-94e7-e3d426903924" @@ -1642,17 +1573,6 @@ version = "0.2.1" uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" version = "0.7.0" -[[deps.SIMDTypes]] -git-tree-sha1 = "330289636fb8107c5f32088d2741e9fd7a061a5c" -uuid = "94e857df-77ce-4151-89e5-788b33177be4" -version = "0.1.0" - -[[deps.SLEEFPirates]] -deps = ["IfElse", "Static", "VectorizationBase"] -git-tree-sha1 = "4b8586aece42bee682399c4c4aee95446aa5cd19" -uuid = "476501e8-09a2-5ece-8869-fb82de89a1fa" -version = "0.6.39" - [[deps.SQLite]] deps = ["DBInterface", "Random", "SQLite_jll", "Serialization", "Tables", "WeakRefStrings"] git-tree-sha1 = "eb9a473c9b191ced349d04efa612ec9f39c087ea" @@ -1770,12 +1690,6 @@ git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" version = "0.1.3" -[[deps.SnoopPrecompile]] -deps = ["Preferences"] -git-tree-sha1 = "e760a70afdcd461cf01a575947738d359234665c" -uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" -version = "1.0.3" - [[deps.Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" @@ -1817,23 +1731,6 @@ git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" version = "0.1.1" -[[deps.Static]] -deps = ["IfElse"] -git-tree-sha1 = "f295e0a1da4ca425659c57441bcb59abb035a4bc" -uuid = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" -version = "0.8.8" - -[[deps.StaticArrayInterface]] -deps = ["ArrayInterface", "Compat", "IfElse", "LinearAlgebra", "Requires", "SnoopPrecompile", "SparseArrays", "Static", "SuiteSparse"] -git-tree-sha1 = "33040351d2403b84afce74dae2e22d3f5b18edcb" -uuid = "0d7ed370-da01-4f52-bd93-41d350b8b718" -version = "1.4.0" -weakdeps = ["OffsetArrays", "StaticArrays"] - - [deps.StaticArrayInterface.extensions] - StaticArrayInterfaceOffsetArraysExt = "OffsetArrays" - StaticArrayInterfaceStaticArraysExt = "StaticArrays" - [[deps.StaticArrays]] deps = ["LinearAlgebra", "Random", "StaticArraysCore"] git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" @@ -1944,12 +1841,6 @@ version = "0.1.1" deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" -[[deps.ThreadingUtilities]] -deps = ["ManualMemory"] -git-tree-sha1 = "eda08f7e9818eb53661b3deb74e3159460dfbc27" -uuid = "8290d209-cae3-49c0-8002-c8c24d57dab5" -version = "0.5.2" - [[deps.TiffImages]] deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] git-tree-sha1 = "8621f5c499a8aa4aa970b1ae381aae0ef1576966" @@ -2012,12 +1903,6 @@ git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" version = "0.1.3" -[[deps.VectorizationBase]] -deps = ["ArrayInterface", "CPUSummary", "HostCPUFeatures", "IfElse", "LayoutPointers", "Libdl", "LinearAlgebra", "SIMDTypes", "Static", "StaticArrayInterface"] -git-tree-sha1 = "b182207d4af54ac64cbc71797765068fdeff475d" -uuid = "3d5dd08c-fd9d-11e8-17fa-ed2836048c2f" -version = "0.21.64" - [[deps.WeakRefStrings]] deps = ["DataAPI", "InlineStrings", "Parsers"] git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" diff --git a/docs/Project.toml b/docs/Project.toml index c5e5ccd0..da78bb9f 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -25,7 +25,7 @@ CairoMakie = "0.10.4" DataFrames = "1.5" Distributions = "0.25" Documenter = "0.27" -EvoTrees = "0.14" +EvoTrees = "0.16" Literate = "2.13" MLJ = "0.19.1" MLJLinearModels = "0.9.1" From e5cbefe5646abf848a43b3fa22ac1e4b00ced8e6 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 13 Aug 2023 14:28:54 +0100 Subject: [PATCH 042/151] remove extension for now --- ext/GraphsTMLEExt.jl | 50 ++++++++++++++++++++++++++++++++++++++++++++ test/Project.toml | 3 --- 2 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 ext/GraphsTMLEExt.jl diff --git a/ext/GraphsTMLEExt.jl b/ext/GraphsTMLEExt.jl new file mode 100644 index 00000000..fca1ae34 --- /dev/null +++ b/ext/GraphsTMLEExt.jl @@ -0,0 +1,50 @@ +module GraphsTMLEExt + +using TMLE +using Graphs +using CairoMakie +using GraphMakie + +function edges_list_and_nodes_mapping(scm::SCM) + edge_list = [] + node_ids = Dict() + node_id = 1 + for (outcome, eq) in equations(scm) + if outcome ∉ keys(node_ids) + push!(node_ids, outcome => node_id) + node_id += 1 + end + outcome_id = node_ids[outcome] + for parent in TMLE.parents(eq) + if parent ∉ keys(node_ids) + push!(node_ids, parent => node_id) + node_id += 1 + end + parent_id = node_ids[parent] + push!(edge_list, Edge(parent_id, outcome_id)) + end + end + return edge_list, node_ids +end + +function DAG(scm::SCM) + edges_list, nodes_mapping = edges_list_and_nodes_mapping(scm) + return SimpleDiGraphFromIterator(edges_list), nodes_mapping +end + +function graphplot(scm::SCM; kwargs...) + graph, nodes_mapping = DAG(scm) + f, ax, p = graphplot(graph; + ilabels = [x[1] for x in sort(collect(nodes_mapping), by=x -> x[2])], + kwargs...) + Legend( + f[1, 2], + [MarkerElement(color=:black, marker='⋆', markersize = 10) for eq in equations(scm)], + [TMLE.string_repr(eq) for (_, eq) in equations(scm)]) + hidedecorations!(ax); + hidespines!(ax); + ax.aspect = DataAspect() + return fig, ax, p +end + +end \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index d9942ee4..c9b22e77 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,12 +1,9 @@ [deps] CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" -Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" From 910bf04f0f5b453a1699c34a3604bef9768712b7 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 13 Aug 2023 14:40:08 +0100 Subject: [PATCH 043/151] actually remove ext --- ext/GraphsTMLEExt.jl | 50 -------------------------------------------- 1 file changed, 50 deletions(-) delete mode 100644 ext/GraphsTMLEExt.jl diff --git a/ext/GraphsTMLEExt.jl b/ext/GraphsTMLEExt.jl deleted file mode 100644 index fca1ae34..00000000 --- a/ext/GraphsTMLEExt.jl +++ /dev/null @@ -1,50 +0,0 @@ -module GraphsTMLEExt - -using TMLE -using Graphs -using CairoMakie -using GraphMakie - -function edges_list_and_nodes_mapping(scm::SCM) - edge_list = [] - node_ids = Dict() - node_id = 1 - for (outcome, eq) in equations(scm) - if outcome ∉ keys(node_ids) - push!(node_ids, outcome => node_id) - node_id += 1 - end - outcome_id = node_ids[outcome] - for parent in TMLE.parents(eq) - if parent ∉ keys(node_ids) - push!(node_ids, parent => node_id) - node_id += 1 - end - parent_id = node_ids[parent] - push!(edge_list, Edge(parent_id, outcome_id)) - end - end - return edge_list, node_ids -end - -function DAG(scm::SCM) - edges_list, nodes_mapping = edges_list_and_nodes_mapping(scm) - return SimpleDiGraphFromIterator(edges_list), nodes_mapping -end - -function graphplot(scm::SCM; kwargs...) - graph, nodes_mapping = DAG(scm) - f, ax, p = graphplot(graph; - ilabels = [x[1] for x in sort(collect(nodes_mapping), by=x -> x[2])], - kwargs...) - Legend( - f[1, 2], - [MarkerElement(color=:black, marker='⋆', markersize = 10) for eq in equations(scm)], - [TMLE.string_repr(eq) for (_, eq) in equations(scm)]) - hidedecorations!(ax); - hidespines!(ax); - ax.aspect = DataAspect() - return fig, ax, p -end - -end \ No newline at end of file From 7513f51c6b5efca96930f882ea72ea4267d608e7 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 13 Aug 2023 15:21:25 +0100 Subject: [PATCH 044/151] more tests --- src/estimate.jl | 48 ++++++++++++++++++++++++------------- test/composition.jl | 33 +++++++++++++++++++++---- test/non_regression_test.jl | 15 ++++++++++-- 3 files changed, 73 insertions(+), 23 deletions(-) diff --git a/src/estimate.jl b/src/estimate.jl index 9cec1caf..1ebd3adf 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -11,18 +11,32 @@ struct OSEstimate{T<:AbstractFloat} <: AsymptoticallyLinearEstimate IC::Vector{T} end -struct ComposedTMLEstimate{T<:AbstractFloat} <: AsymptoticallyLinearEstimate +struct ComposedEstimate{T<:AbstractFloat} <: AsymptoticallyLinearEstimate Ψ̂::Array{T} σ̂::Matrix{T} + n::Int end -struct TMLEResult{P <: Estimand, T<:AbstractFloat} +struct TMLEResult{P <: Estimand, T<: AbstractFloat} estimand::P tmle::TMLEstimate{T} onestep::OSEstimate{T} initial::T end +function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) + if length(est.σ̂) !== 1 + println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) + else + testresult = OneSampleTTest(est) + data = [estimate(est) confint(testresult) pvalue(testresult);] + headers = ["Estimate", "95% Confidence Interval", "P-value"] + pretty_table(io, data;header=headers) + end +end + + + function Base.show(io::IO, ::MIME"text/plain", est::AsymptoticallyLinearEstimate) testresult = OneSampleTTest(est) data = [estimate(est) confint(testresult) pvalue(testresult);] @@ -52,11 +66,11 @@ Retrieves the final estimate: after the TMLE step. Distributions.estimate(r::AsymptoticallyLinearEstimate) = r.Ψ̂ """ - Distributions.estimate(r::ComposedTMLEstimate) + Distributions.estimate(r::ComposedEstimate) Retrieves the final estimate: after the TMLE step. """ -Distributions.estimate(r::ComposedTMLEstimate) = length(r.Ψ̂) == 1 ? r.Ψ̂[1] : r.Ψ̂ +Distributions.estimate(r::ComposedEstimate) = length(r.Ψ̂) == 1 ? r.Ψ̂[1] : r.Ψ̂ """ var(r::AsymptoticallyLinearEstimate) @@ -66,11 +80,12 @@ Computes the estimated variance associated with the estimate. Statistics.var(r::AsymptoticallyLinearEstimate) = var(r.IC)/size(r.IC, 1) """ - var(r::ComposedTMLEstimate) + var(r::ComposedEstimate) Computes the estimated variance associated with the estimate. """ -Statistics.var(r::ComposedTMLEstimate) = length(r.σ̂) == 1 ? r.σ̂[1] : r.σ̂ +Statistics.var(r::ComposedEstimate) = length(r.σ̂) == 1 ? r.σ̂[1] / r.n : r.σ̂ ./ r.n + """ OneSampleZTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) @@ -87,26 +102,25 @@ Performs a T test on the AsymptoticallyLinearEstimate. HypothesisTests.OneSampleTTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) = OneSampleTTest(estimate(r), std(r.IC), size(r.IC, 1), Ψ₀) """ - OneSampleTTest(r::ComposedTMLEstimate, Ψ₀=0) + OneSampleTTest(r::ComposedEstimate, Ψ₀=0) -Performs a T test on the ComposedTMLEstimate. +Performs a T test on the ComposedEstimate. """ -function HypothesisTests.OneSampleTTest(r::ComposedTMLEstimate, Ψ₀=0) +function HypothesisTests.OneSampleTTest(r::ComposedEstimate, Ψ₀=0) @assert length(r.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleTTest(estimate(r), sqrt(var(r)), 1, Ψ₀) + return OneSampleTTest(estimate(r), sqrt(r.σ̂[1]), r.n, Ψ₀) end """ - OneSampleZTest(r::ComposedTMLEstimate, Ψ₀=0) + OneSampleZTest(r::ComposedEstimate, Ψ₀=0) -Performs a T test on the ComposedTMLEstimate. +Performs a T test on the ComposedEstimate. """ -function HypothesisTests.OneSampleZTest(r::ComposedTMLEstimate, Ψ₀=0) +function HypothesisTests.OneSampleZTest(r::ComposedEstimate, Ψ₀=0) @assert length(r.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleZTest(estimate(r), sqrt(var(r)), 1, Ψ₀) + return OneSampleZTest(estimate(r), sqrt(r.σ̂[1]), r.n, Ψ₀) end - """ compose(f, estimation_results::Vararg{AsymptoticallyLinearEstimate, N}) where N @@ -158,8 +172,8 @@ function compose(f, estimators::Vararg{AsymptoticallyLinearEstimate, N}; backend f₀, Js = AD.value_and_jacobian(backend, f, estimates...) J = hcat(Js...) n = size(first(estimators).IC, 1) - σ₀ = (J*Σ*J')/n - return ComposedTMLEstimate(collect(f₀), σ₀) + σ₀ = J*Σ*J' + return ComposedEstimate(collect(f₀), σ₀, n) end function Statistics.cov(estimators::Vararg{AsymptoticallyLinearEstimate, N}) where N diff --git a/test/composition.jl b/test/composition.jl index 393a38d5..eb9a99e9 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -48,8 +48,10 @@ end treatment = (T=0,) ) CM_result₀, _ = tmle!(CM₀, dataset, verbosity=0) - # Via Composition - CM_result_composed = compose(-, CM_result₁.tmle, CM_result₀.tmle) + # Composition of TMLE + + CM_result_composed_tmle = compose(-, tmle(CM_result₁), tmle(CM_result₀)); + CM_result_composed_ose = compose(-, ose(CM_result₁), ose(CM_result₀)); # Via ATE ATE₁₀ = ATE( @@ -58,8 +60,31 @@ end treatment = (T=(case=1, control=0),), ) ATE_result₁₀, _ = tmle!(ATE₁₀, dataset, verbosity=0) - @test estimate(ATE_result₁₀.tmle) ≈ estimate(CM_result_composed) atol = 1e-7 - @test var(ATE_result₁₀.tmle) ≈ var(CM_result_composed) atol = 1e-7 + # Check composed TMLE + @test estimate(tmle(ATE_result₁₀)) ≈ estimate(CM_result_composed_tmle) atol = 1e-7 + # T Test + composed_confint = collect(confint(OneSampleTTest(CM_result_composed_tmle))) + tmle_confint = collect(confint(OneSampleTTest(tmle(ATE_result₁₀)))) + @test tmle_confint ≈ composed_confint atol=1e-4 + # Z Test + composed_confint = collect(confint(OneSampleZTest(CM_result_composed_tmle))) + tmle_confint = collect(confint(OneSampleZTest(tmle(ATE_result₁₀)))) + @test tmle_confint ≈ composed_confint atol=1e-4 + # Variance + @test var(tmle(ATE_result₁₀)) ≈ var(CM_result_composed_tmle) atol = 1e-3 + + # Check composed OSE + @test estimate(ose(ATE_result₁₀)) ≈ estimate(CM_result_composed_ose) atol = 1e-7 + # T Test + composed_confint = collect(confint(OneSampleTTest(CM_result_composed_ose))) + ose_confint = collect(confint(OneSampleTTest(ose(ATE_result₁₀)))) + @test ose_confint ≈ composed_confint atol=1e-4 + # Z Test + composed_confint = collect(confint(OneSampleZTest(CM_result_composed_tmle))) + ose_confint = collect(confint(OneSampleZTest(ose(ATE_result₁₀)))) + @test ose_confint ≈ composed_confint atol=1e-4 + # Variance + @test var(ose(ATE_result₁₀)) ≈ var(CM_result_composed_ose) atol = 1e-3 end @testset "Test compose multidimensional function" begin diff --git a/test/non_regression_test.jl b/test/non_regression_test.jl index f2160f4a..9814c140 100644 --- a/test/non_regression_test.jl +++ b/test/non_regression_test.jl @@ -24,10 +24,21 @@ using MLJGLMInterface Ψ = ATE(scm, outcome=:haz01, treatment=(parity01=(case=1, control=0),)) result, fluctuation_mach = tmle!(Ψ, dataset, threshold=0.025, verbosity=0) - l, u = confint(OneSampleTTest(tmle(result))) - @test estimate(result.tmle) ≈ -0.185533 atol = 1e-6 + # TMLE + tmle_result = tmle(result) + @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 + l, u = confint(OneSampleTTest(tmle_result)) @test l ≈ -0.279246 atol = 1e-6 @test u ≈ -0.091821 atol = 1e-6 + @test OneSampleZTest(tmle_result) isa OneSampleZTest + # OSE + ose_result = ose(result) + @test estimate(ose_result) isa Float64 + @test OneSampleTTest(ose_result) isa OneSampleTTest + @test OneSampleZTest(ose_result) isa OneSampleZTest + # Naive + @test initial(result) isa Float64 + end end From 32467ae45f2cfcc9a6373e4599beb34a0f52a743 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 13 Aug 2023 15:30:59 +0100 Subject: [PATCH 045/151] up def of with_encoder --- src/treatment_transformer.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/treatment_transformer.jl b/src/treatment_transformer.jl index 8980647a..510a9f12 100644 --- a/src/treatment_transformer.jl +++ b/src/treatment_transformer.jl @@ -41,4 +41,4 @@ function MLJBase.transform(model::TreatmentTransformer, fitresult, Xnew) return merge(Xt, ordered_factors) end -with_encoder(model; encoder=encoder()) = TreatmentTransformer(;encoder=encoder) |> model \ No newline at end of file +with_encoder(model; encoder=encoder()) = Pipeline(TreatmentTransformer(;encoder=encoder), model) \ No newline at end of file From e0ed8f788347b4c339b361c732372ea388a3160a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 14 Aug 2023 07:48:36 +0100 Subject: [PATCH 046/151] add data adaptive truncation --- docs/src/user_guide/estimation.md | 8 ++++---- src/TMLE.jl | 1 - src/estimation.jl | 31 ++++++++++++++++++++----------- src/utils.jl | 31 ++++++++++++++++++++++++------- test/non_regression_test.jl | 2 +- test/utils.jl | 7 +++++++ 6 files changed, 56 insertions(+), 24 deletions(-) diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index c8805c0b..3013a810 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -58,7 +58,7 @@ result₁, fluctuation_mach = tmle!(Ψ₁, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, - threshold=1e-8, + ps_lowerbound=1e-8, weighted_fluctuation=false ) nothing # hide @@ -92,7 +92,7 @@ result₂, fluctuation_mach = tmle!(Ψ₂, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, - threshold=1e-8, + ps_lowerbound=1e-8, weighted_fluctuation=false ) nothing # hide @@ -110,7 +110,7 @@ result₃, fluctuation_mach = tmle!(Ψ₃, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, - threshold=1e-8, + ps_lowerbound=1e-8, weighted_fluctuation=false ) nothing # hide @@ -124,7 +124,7 @@ result₄, fluctuation_mach = tmle!(Ψ₄, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1, force=false, - threshold=1e-8, + ps_lowerbound=1e-8, weighted_fluctuation=false ) nothing # hide diff --git a/src/TMLE.jl b/src/TMLE.jl index f877a479..0acd4a20 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -36,7 +36,6 @@ export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer, with_encoder export BackdoorAdjustment -export DAG, graphplot # ############################################################################# # INCLUDES diff --git a/src/estimation.jl b/src/estimation.jl index a7b37488..f1b285de 100644 --- a/src/estimation.jl +++ b/src/estimation.jl @@ -14,7 +14,7 @@ end adjustment_method=BackdoorAdjustment(), verbosity=1, force=false, - threshold=1e-8, + ps_lowerbound=1e-8, weighted_fluctuation=false ) @@ -27,14 +27,14 @@ Performs Targeted Minimum Loss Based Estimation of the target estimand. - adjustment_method: A confounding adjustment method. - verbosity: Level of logging. - force: To force refit of machines in the SCM . -- threshold: The balancing score will be bounded to respect this threshold. +- ps_lowerbound: The propensity score will be truncated to respect this lower bound. - weighted_fluctuation: To use a weighted fluctuation instead of the vanilla TMLE, can improve stability. """ function tmle!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false, - threshold=1e-8, + ps_lowerbound=1e-8, weighted_fluctuation=false ) # Check the estimand against the dataset @@ -47,10 +47,18 @@ function tmle!(Ψ::CMCompositeEstimand, dataset; force=force ) # TMLE step + ps_lowerbound = ps_lower_bound(Ψ, ps_lowerbound) verbosity >= 1 && @info "Performing TMLE..." - fluctuation_mach = tmle_step(Ψ, verbosity=verbosity, threshold=threshold, weighted_fluctuation=weighted_fluctuation) + fluctuation_mach = tmle_step(Ψ; + verbosity=verbosity, + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) # Estimation results after TMLE - IC, Ψ̂, ICᵢ, Ψ̂ᵢ = gradient_and_estimates(Ψ, fluctuation_mach, threshold=threshold, weighted_fluctuation=weighted_fluctuation) + IC, Ψ̂, ICᵢ, Ψ̂ᵢ = gradient_and_estimates(Ψ, fluctuation_mach, + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) tmle_result = TMLEstimate(Ψ̂, IC) one_step_result = OSEstimate(Ψ̂ᵢ + mean(ICᵢ), ICᵢ) @@ -58,7 +66,7 @@ function tmle!(Ψ::CMCompositeEstimand, dataset; TMLEResult(Ψ, tmle_result, one_step_result, Ψ̂ᵢ), fluctuation_mach end -function tmle_step(Ψ::CMCompositeEstimand; verbosity=1, threshold=1e-8, weighted_fluctuation=false) +function tmle_step(Ψ::CMCompositeEstimand; verbosity=1, ps_lowerbound=1e-8, weighted_fluctuation=false) outcome_equation = Ψ.scm[outcome(Ψ)] X, y = outcome_equation.mach.data outcome_mach = outcome_equation.mach @@ -69,7 +77,7 @@ function tmle_step(Ψ::CMCompositeEstimand; verbosity=1, threshold=1e-8, weighte W = confounders(X, Ψ) covariate, weights = TMLE.clever_covariate_and_weights( Ψ.scm, W, T, indicator_fns(Ψ); - threshold=threshold, + ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) # Fit fluctuation @@ -79,7 +87,7 @@ function tmle_step(Ψ::CMCompositeEstimand; verbosity=1, threshold=1e-8, weighte return mach end -function counterfactual_aggregates(Ψ::CMCompositeEstimand, fluctuation_mach; threshold=1e-8, weighted_fluctuation=false) +function counterfactual_aggregates(Ψ::CMCompositeEstimand, fluctuation_mach; ps_lowerbound=1e-8, weighted_fluctuation=false) outcome_eq = outcome_equation(Ψ) X = outcome_eq.mach.data[1] Ttemplate = treatments(X, Ψ) @@ -100,7 +108,8 @@ function counterfactual_aggregates(Ψ::CMCompositeEstimand, fluctuation_mach; th offset = compute_offset(ŷᵢ) covariate, _ = clever_covariate_and_weights( Ψ.scm, W, T_ct, indicators; - threshold=threshold, weighted_fluctuation=weighted_fluctuation + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation ) Xfluct_ct = fluctuation_input(covariate, offset) ŷ = predict(fluctuation_mach, Xfluct_ct) @@ -135,11 +144,11 @@ function gradients_Y_X(outcome_mach::Machine, fluctuation_mach::Machine) end -function gradient_and_estimates(Ψ::CMCompositeEstimand, fluctuation_mach::Machine; threshold=1e-8, weighted_fluctuation=false) +function gradient_and_estimates(Ψ::CMCompositeEstimand, fluctuation_mach::Machine; ps_lowerbound=1e-8, weighted_fluctuation=false) counterfactual_aggregate, counterfactual_aggregateᵢ = TMLE.counterfactual_aggregates( Ψ, fluctuation_mach; - threshold=threshold, + ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) Ψ̂, Ψ̂ᵢ = mean(counterfactual_aggregate), mean(counterfactual_aggregateᵢ) diff --git a/src/utils.jl b/src/utils.jl index 59293c0e..65eed255 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,15 +4,32 @@ selectcols(data, cols) = data |> TableOperations.select(cols...) |> Tables.columntable +data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = + min(5 / (√(n)*log(n/5)), max_lb) +""" + data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand) + +This startegy is from [this paper](https://academic.oup.com/aje/article/191/9/1640/6580570?login=false) +but the study does not show strictly better behaviour of the strategy so not a default for now. +""" +function data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand;max_lb=0.1) + n = nrows(outcome_equation(Ψ).mach.data[2]) + return data_adaptive_ps_lower_bound(n; max_lb=max_lb) +end + +ps_lower_bound(Ψ::CMCompositeEstimand, lower_bound::Nothing; max_lb=0.1) = data_adaptive_ps_lower_bound(Ψ; max_lb=max_lb) +ps_lower_bound(Ψ::CMCompositeEstimand, lower_bound; max_lb=0.1) = min(max_lb, lower_bound) + + function logit!(v) for i in eachindex(v) v[i] = logit(v[i]) end end -function plateau!(v::AbstractVector, threshold) +function plateau!(v::AbstractVector, ps_lowerbound) for i in eachindex(v) - v[i] = max(v[i], threshold) + v[i] = max(v[i], ps_lowerbound) end end @@ -143,19 +160,19 @@ compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) compute_offset(Ψ::CMCompositeEstimand) = compute_offset(MLJBase.predict(outcome_equation(Ψ).mach)) -function balancing_weights(scm::SCM, W, T; threshold=1e-8) +function balancing_weights(scm::SCM, W, T; ps_lowerbound=1e-8) density = ones(nrows(T)) for colname ∈ Tables.columnnames(T) mach = scm[colname].mach ŷ = MLJBase.predict(mach, W[colname]) density .*= pdf.(ŷ, Tables.getcolumn(T, colname)) end - plateau!(density, threshold) + plateau!(density, ps_lowerbound) return 1. ./ density end """ - clever_covariate_and_weights(jointT, W, G, indicator_fns; threshold=1e-8, weighted_fluctuation=false) + clever_covariate_and_weights(jointT, W, G, indicator_fns; ps_lowerbound=1e-8, weighted_fluctuation=false) Computes the clever covariate and weights that are used to fluctuate the initial Q. @@ -171,10 +188,10 @@ if `weighted_fluctuation = true`: where SpecialIndicator(t) is defined in `indicator_fns`. """ -function clever_covariate_and_weights(scm::SCM, W, T, indicator_fns; threshold=1e-8, weighted_fluctuation=false) +function clever_covariate_and_weights(scm::SCM, W, T, indicator_fns; ps_lowerbound=1e-8, weighted_fluctuation=false) # Compute the indicator values indic_vals = TMLE.indicator_values(indicator_fns, T) - weights = balancing_weights(scm, W, T, threshold=threshold) + weights = balancing_weights(scm, W, T, ps_lowerbound=ps_lowerbound) if weighted_fluctuation return indic_vals, weights end diff --git a/test/non_regression_test.jl b/test/non_regression_test.jl index 9814c140..99d5f201 100644 --- a/test/non_regression_test.jl +++ b/test/non_regression_test.jl @@ -23,7 +23,7 @@ using MLJGLMInterface ) Ψ = ATE(scm, outcome=:haz01, treatment=(parity01=(case=1, control=0),)) - result, fluctuation_mach = tmle!(Ψ, dataset, threshold=0.025, verbosity=0) + result, fluctuation_mach = tmle!(Ψ, dataset, ps_lowerbound=0.025, verbosity=0) # TMLE tmle_result = tmle(result) @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 diff --git a/test/utils.jl b/test/utils.jl index 22e46056..a41c1635 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -175,6 +175,13 @@ end W₃ = rand(7) ) fit!(Ψ, dataset, verbosity=0) + + @test TMLE.ps_lower_bound(Ψ, nothing) == 0.1 + @test TMLE.data_adaptive_ps_lower_bound(Ψ) == 0.1 + @test TMLE.data_adaptive_ps_lower_bound(1000) == 0.02984228238321508 + @test TMLE.ps_lower_bound(Ψ, 1e-8) == 1.0e-8 + @test TMLE.ps_lower_bound(Ψ, 1) == 0.1 + indicator_fns = TMLE.indicator_fns(Ψ) T = TMLE.treatments(dataset, Ψ) @test T == (T=dataset.T,) From 564ecc5dec9ac24c6b28b6e383b87d98f3b65f7d Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 14 Aug 2023 07:49:24 +0100 Subject: [PATCH 047/151] rename plateau to truncate --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 65eed255..b73c9d2f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -27,7 +27,7 @@ function logit!(v) end end -function plateau!(v::AbstractVector, ps_lowerbound) +function truncate!(v::AbstractVector, ps_lowerbound) for i in eachindex(v) v[i] = max(v[i], ps_lowerbound) end @@ -167,7 +167,7 @@ function balancing_weights(scm::SCM, W, T; ps_lowerbound=1e-8) ŷ = MLJBase.predict(mach, W[colname]) density .*= pdf.(ŷ, Tables.getcolumn(T, colname)) end - plateau!(density, ps_lowerbound) + truncate!(density, ps_lowerbound) return 1. ./ density end From 64c37d631fff1c8900532a61a17522ff8c8223b7 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 16 Aug 2023 17:46:56 +0200 Subject: [PATCH 048/151] add dodgy workload for now --- src/TMLE.jl | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index 0acd4a20..5c313327 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -81,13 +81,20 @@ function run_precompile_workload() setmodel!(scm.T₁, LinearBinaryClassifier()) ## Implicit through estimand for estimand_type in [CM, ATE, IATE] - estimand_type(outcome=:Y₁, treatment=(T₁=true,), confounders=[:W₁, :W₂]) + estimand_type( + outcome=:Y₁, + treatment=(T₁=true,), + confounders=[:W₁, :W₂], + outcome_model=LinearRegressor() + ) end ## Complete + # Not using the `with_encoder` for now because it crashes precompilation + # and fit can still happen somehow. scm = SCM( - SE(:Y₁, [:T₁, :W₁, :W₂], model=with_encoder(LinearRegressor())), + SE(:Y₁, [:T₁, :W₁, :W₂], model=LinearRegressor()), SE(:T₁, [:W₁, :W₂],model=LinearBinaryClassifier()), - SE(:Y₂, [:T₁, :T₂, :W₁, :W₂, :C₁], model=with_encoder(LinearBinaryClassifier())), + SE(:Y₂, [:T₁, :T₂, :W₁, :W₂, :C₁], model=LinearBinaryClassifier()), SE(:T₂, [:W₁, :W₂],model=LinearBinaryClassifier()), ) @@ -97,31 +104,32 @@ function run_precompile_workload() outcome =:Y₁, treatment=(T₁=true,), ) - result₁, fluctuation = tmle!(Ψ₁, dataset) + result₁, fluctuation = tmle!(Ψ₁, dataset, verbosity=0) Ψ₂ = ATE( scm, outcome=:Y₂, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)) ) - result₂, fluctuation = tmle!(Ψ₂, dataset) + result₂, fluctuation = tmle!(Ψ₂, dataset, verbosity=0) Ψ₃ = IATE( scm, outcome=:Y₂, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)) ) - result₃, fluctuation = tmle!(Ψ₃, dataset) + result₃, fluctuation = tmle!(Ψ₃, dataset, verbosity=0) # Composition composed_result = compose((x,y) -> x - y, tmle(result₂), tmle(result₁)) # Results manipulation - initial(result) + initial(result₃) OneSampleTTest(tmle(result₃)) OneSampleZTest(tmle(result₃)) OneSampleTTest(ose(result₃)) OneSampleZTest(ose(result₃)) + OneSampleTTest(composed_result) OneSampleZTest(composed_result) end @@ -129,6 +137,6 @@ function run_precompile_workload() end -# run_precompile_workload() +run_precompile_workload() end From 654a21f00aec104654a71d27526e5fde36e78e69 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 11:30:28 +0100 Subject: [PATCH 049/151] add fluctuation model --- src/TMLE.jl | 6 +- src/estimands.jl | 10 +- src/estimate.jl | 85 +++++++---------- src/estimation.jl | 135 +++++++-------------------- src/fluctuation.jl | 86 +++++++++++++++++ src/gradient.jl | 57 ++++++++++++ src/utils.jl | 60 ++++-------- test/composition.jl | 37 ++++---- test/double_robustness_ate.jl | 16 ++-- test/double_robustness_iate.jl | 20 ++-- test/estimands.jl | 3 - test/estimation.jl | 14 +-- test/fluctuation.jl | 118 ++++++++++++++++++++++++ test/gradient.jl | 67 ++++++++++++++ test/helper_fns.jl | 16 ++-- test/non_regression_test.jl | 19 ++-- test/runtests.jl | 2 + test/utils.jl | 164 ++++++++++++++------------------- 18 files changed, 546 insertions(+), 369 deletions(-) create mode 100644 src/fluctuation.jl create mode 100644 src/gradient.jl create mode 100644 test/fluctuation.jl create mode 100644 test/gradient.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index 5c313327..1bbea29c 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -31,7 +31,7 @@ export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS export fit!, optimize_ordering, optimize_ordering! -export tmle!, tmle, ose, initial +export tmle!, ose!, initial export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer, with_encoder @@ -43,7 +43,9 @@ export BackdoorAdjustment include("scm.jl") include("estimands.jl") +include("fluctuation.jl") include("adjustment.jl") +include("gradient.jl") include("utils.jl") include("estimation.jl") include("estimate.jl") @@ -137,6 +139,6 @@ function run_precompile_workload() end -run_precompile_workload() +# run_precompile_workload() end diff --git a/src/estimands.jl b/src/estimands.jl index 0bb965ac..b88d4efe 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -188,6 +188,8 @@ end outcome_equation(Ψ::CMCompositeEstimand) = Ψ.scm[outcome(Ψ)] +getQ(Ψ::CMCompositeEstimand) = outcome_equation(Ψ).mach + treatments(Ψ::CMCompositeEstimand) = collect(keys(Ψ.treatment)) treatments(dataset, Ψ::CMCompositeEstimand) = selectcols(dataset, treatments(Ψ)) @@ -196,14 +198,6 @@ outcome(dataset, Ψ::CMCompositeEstimand) = Tables.getcolumn(dataset, outcome(Ψ confounders(dataset, Ψ) = (;(T => selectcols(dataset, keys(Ψ.scm[T].mach.data[1])) for T in treatments(Ψ))...) -F_model(::Type{<:AbstractVector{<:MLJBase.Continuous}}) = - LinearRegressor(fit_intercept=false, offsetcol = :offset) - -F_model(::Type{<:AbstractVector{<:Finite}}) = - LinearBinaryClassifier(fit_intercept=false, offsetcol = :offset) - -F_model(t::Type{Any}) = throw(ArgumentError("Cannot proceed with Q model with target_scitype $t")) - function param_key(Ψ::CMCompositeEstimand) return ( join(treatments(Ψ), "_"), diff --git a/src/estimate.jl b/src/estimate.jl index 1ebd3adf..abfb7996 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -1,29 +1,24 @@ -abstract type AsymptoticallyLinearEstimate end +abstract type EICEstimate end -struct TMLEstimate{T<:AbstractFloat} <: AsymptoticallyLinearEstimate +struct TMLEstimate{T<:AbstractFloat} <: EICEstimate + Ψ̂⁰::T Ψ̂::T IC::Vector{T} end -struct OSEstimate{T<:AbstractFloat} <: AsymptoticallyLinearEstimate +struct OSEstimate{T<:AbstractFloat} <: EICEstimate + Ψ̂⁰::T Ψ̂::T IC::Vector{T} end -struct ComposedEstimate{T<:AbstractFloat} <: AsymptoticallyLinearEstimate +struct ComposedEstimate{T<:AbstractFloat} Ψ̂::Array{T} σ̂::Matrix{T} n::Int end -struct TMLEResult{P <: Estimand, T<: AbstractFloat} - estimand::P - tmle::TMLEstimate{T} - onestep::OSEstimate{T} - initial::T -end - function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) if length(est.σ̂) !== 1 println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) @@ -35,80 +30,70 @@ function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) end end - - -function Base.show(io::IO, ::MIME"text/plain", est::AsymptoticallyLinearEstimate) +function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) testresult = OneSampleTTest(est) data = [estimate(est) confint(testresult) pvalue(testresult);] pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"]) end -function Base.show(io::IO, ::MIME"text/plain", r::TMLEResult) - tmletest = OneSampleTTest(r.tmle) - onesteptest = OneSampleTTest(r.onestep) - data = [ - :TMLE estimate(tmle(r)) confint(tmletest) pvalue(tmletest); - :OSE estimate(ose(r)) confint(onesteptest) pvalue(onesteptest); - :Naive r.initial nothing nothing - ] - pretty_table(io, data;header=["Estimator", "Estimate", "95% Confidence Interval", "P-value"]) -end - -tmle(result::TMLEResult) = result.tmle -ose(result::TMLEResult) = result.onestep -initial(result::TMLEResult) = result.initial +initial(est::EICEstimate) = est.Ψ̂⁰ """ - Distributions.estimate(r::AsymptoticallyLinearEstimate) + Distributions.estimate(r::EICEstimate) Retrieves the final estimate: after the TMLE step. """ -Distributions.estimate(r::AsymptoticallyLinearEstimate) = r.Ψ̂ +Distributions.estimate(est::EICEstimate) = est.Ψ̂ """ Distributions.estimate(r::ComposedEstimate) Retrieves the final estimate: after the TMLE step. """ -Distributions.estimate(r::ComposedEstimate) = length(r.Ψ̂) == 1 ? r.Ψ̂[1] : r.Ψ̂ +Distributions.estimate(est::ComposedEstimate) = + length(est.Ψ̂) == 1 ? est.Ψ̂[1] : est.Ψ̂ """ - var(r::AsymptoticallyLinearEstimate) + var(r::EICEstimate) Computes the estimated variance associated with the estimate. """ -Statistics.var(r::AsymptoticallyLinearEstimate) = var(r.IC)/size(r.IC, 1) +Statistics.var(est::EICEstimate) = + var(est.IC)/size(est.IC, 1) """ var(r::ComposedEstimate) Computes the estimated variance associated with the estimate. """ -Statistics.var(r::ComposedEstimate) = length(r.σ̂) == 1 ? r.σ̂[1] / r.n : r.σ̂ ./ r.n +Statistics.var(est::ComposedEstimate) = + length(est.σ̂) == 1 ? est.σ̂[1] / est.n : est.σ̂ ./ est.n """ - OneSampleZTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) + OneSampleZTest(r::EICEstimate, Ψ₀=0) -Performs a Z test on the AsymptoticallyLinearEstimate. +Performs a Z test on the EICEstimate. """ -HypothesisTests.OneSampleZTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) = OneSampleZTest(estimate(r), std(r.IC), size(r.IC, 1), Ψ₀) +HypothesisTests.OneSampleZTest(est::EICEstimate, Ψ₀=0) = + OneSampleZTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) """ - OneSampleTTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) + OneSampleTTest(r::EICEstimate, Ψ₀=0) -Performs a T test on the AsymptoticallyLinearEstimate. +Performs a T test on the EICEstimate. """ -HypothesisTests.OneSampleTTest(r::AsymptoticallyLinearEstimate, Ψ₀=0) = OneSampleTTest(estimate(r), std(r.IC), size(r.IC, 1), Ψ₀) +HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) = + OneSampleTTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) """ OneSampleTTest(r::ComposedEstimate, Ψ₀=0) Performs a T test on the ComposedEstimate. """ -function HypothesisTests.OneSampleTTest(r::ComposedEstimate, Ψ₀=0) - @assert length(r.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleTTest(estimate(r), sqrt(r.σ̂[1]), r.n, Ψ₀) +function HypothesisTests.OneSampleTTest(est::ComposedEstimate, Ψ₀=0) + @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleTTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) end """ @@ -116,13 +101,13 @@ end Performs a T test on the ComposedEstimate. """ -function HypothesisTests.OneSampleZTest(r::ComposedEstimate, Ψ₀=0) - @assert length(r.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleZTest(estimate(r), sqrt(r.σ̂[1]), r.n, Ψ₀) +function HypothesisTests.OneSampleZTest(est::ComposedEstimate, Ψ₀=0) + @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleZTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) end """ - compose(f, estimation_results::Vararg{AsymptoticallyLinearEstimate, N}) where N + compose(f, estimation_results::Vararg{EICEstimate, N}) where N Provides an estimator of f(estimation_results...). @@ -155,7 +140,7 @@ Hence, the only thing we need to do is: # Arguments - f: An array-input differentiable map. -- estimation_results: 1 or more `AsymptoticallyLinearEstimate` structs. +- estimation_results: 1 or more `EICEstimate` structs. # Examples @@ -166,7 +151,7 @@ f(x, y) = [x^2 - y, y - 3x] compose(f, res₁, res₂) ``` """ -function compose(f, estimators::Vararg{AsymptoticallyLinearEstimate, N}; backend=AD.ZygoteBackend()) where N +function compose(f, estimators::Vararg{EICEstimate, N}; backend=AD.ZygoteBackend()) where N Σ = cov(estimators...) estimates = [estimate(r) for r in estimators] f₀, Js = AD.value_and_jacobian(backend, f, estimates...) @@ -176,7 +161,7 @@ function compose(f, estimators::Vararg{AsymptoticallyLinearEstimate, N}; backend return ComposedEstimate(collect(f₀), σ₀, n) end -function Statistics.cov(estimators::Vararg{AsymptoticallyLinearEstimate, N}) where N +function Statistics.cov(estimators::Vararg{EICEstimate, N}) where N X = hcat([r.IC for r in estimators]...) return Statistics.cov(X, dims=1, corrected=true) end diff --git a/src/estimation.jl b/src/estimation.jl index f1b285de..4713a9d0 100644 --- a/src/estimation.jl +++ b/src/estimation.jl @@ -46,115 +46,46 @@ function tmle!(Ψ::CMCompositeEstimand, dataset; verbosity=verbosity, force=force ) - # TMLE step + # Get propensity score truncation threshold ps_lowerbound = ps_lower_bound(Ψ, ps_lowerbound) + # Fit Fluctuation verbosity >= 1 && @info "Performing TMLE..." - fluctuation_mach = tmle_step(Ψ; - verbosity=verbosity, - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation + Q⁰ = getQ(Ψ) + X, y = Q⁰.data + Q = machine( + Fluctuation(Ψ, 0.1, ps_lowerbound, weighted_fluctuation), + X, + y ) + fit!(Q, verbosity=verbosity-1) # Estimation results after TMLE - IC, Ψ̂, ICᵢ, Ψ̂ᵢ = gradient_and_estimates(Ψ, fluctuation_mach, - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - tmle_result = TMLEstimate(Ψ̂, IC) - one_step_result = OSEstimate(Ψ̂ᵢ + mean(ICᵢ), ICᵢ) - + Ψ̂⁰ = mean(counterfactual_aggregate(Ψ, Q⁰)) + IC, Ψ̂ = gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." - TMLEResult(Ψ, tmle_result, one_step_result, Ψ̂ᵢ), fluctuation_mach -end - -function tmle_step(Ψ::CMCompositeEstimand; verbosity=1, ps_lowerbound=1e-8, weighted_fluctuation=false) - outcome_equation = Ψ.scm[outcome(Ψ)] - X, y = outcome_equation.mach.data - outcome_mach = outcome_equation.mach - # Compute offset - offset = TMLE.compute_offset(MLJBase.predict(outcome_mach)) - # Compute clever covariate and weights for weighted fluctuation mode - T = treatments(X, Ψ) - W = confounders(X, Ψ) - covariate, weights = TMLE.clever_covariate_and_weights( - Ψ.scm, W, T, indicator_fns(Ψ); - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - # Fit fluctuation - Xfluct = TMLE.fluctuation_input(covariate, offset) - mach = machine(F_model(scitype(y)), Xfluct, y, weights, cache=true) - MLJBase.fit!(mach, verbosity=verbosity-1) - return mach -end - -function counterfactual_aggregates(Ψ::CMCompositeEstimand, fluctuation_mach; ps_lowerbound=1e-8, weighted_fluctuation=false) - outcome_eq = outcome_equation(Ψ) - X = outcome_eq.mach.data[1] - Ttemplate = treatments(X, Ψ) - n = nrows(Ttemplate) - counterfactual_aggregateᵢ = zeros(n) - counterfactual_aggregate = zeros(n) - # Loop over Treatment settings - indicators = indicator_fns(Ψ) - for (vals, sign) in indicators - # Counterfactual dataset for a given treatment setting - T_ct = TMLE.counterfactualTreatment(vals, Ttemplate) - X_ct = selectcols(merge(X, T_ct), keys(X)) - W = confounders(X_ct, Ψ) - # Counterfactual predictions with the initial Q - ŷᵢ = MLJBase.predict(outcome_eq.mach, X_ct) - counterfactual_aggregateᵢ .+= sign .* expected_value(ŷᵢ) - # Counterfactual predictions with F - offset = compute_offset(ŷᵢ) - covariate, _ = clever_covariate_and_weights( - Ψ.scm, W, T_ct, indicators; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - Xfluct_ct = fluctuation_input(covariate, offset) - ŷ = predict(fluctuation_mach, Xfluct_ct) - counterfactual_aggregate .+= sign .* expected_value(ŷ) - end - return counterfactual_aggregate, counterfactual_aggregateᵢ + return TMLEstimate(Ψ̂⁰, Ψ̂, IC), Q end -""" - gradient_W(counterfactual_aggregate, estimate) - -∇_W = counterfactual_aggregate - Ψ -""" -gradient_W(counterfactual_aggregate, estimate) = - counterfactual_aggregate .- estimate - - -""" - gradient_Y_X(cache) - -∇_YX(w, t, c) = covariate(w, t) ̇ (y - E[Y|w, t, c]) - -This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. -""" -function gradients_Y_X(outcome_mach::Machine, fluctuation_mach::Machine) - X, y, w = fluctuation_mach.data - y = float(y) - covariate = X.covariate .* w - gradient_Y_Xᵢ = covariate .* (y .- expected_value(MLJBase.predict(outcome_mach))) - gradient_Y_X_fluct = covariate .* (y .- expected_value(MLJBase.predict(fluctuation_mach))) - return gradient_Y_Xᵢ, gradient_Y_X_fluct -end - - -function gradient_and_estimates(Ψ::CMCompositeEstimand, fluctuation_mach::Machine; ps_lowerbound=1e-8, weighted_fluctuation=false) - counterfactual_aggregate, counterfactual_aggregateᵢ = TMLE.counterfactual_aggregates( - Ψ, - fluctuation_mach; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation +function ose!(Ψ::CMCompositeEstimand, dataset; + adjustment_method=BackdoorAdjustment(), + verbosity=1, + force=false, + ps_lowerbound=1e-8) + # Check the estimand against the dataset + check_treatment_levels(Ψ, dataset) + # Initial fit of the SCM's equations + verbosity >= 1 && @info "Fitting the required equations..." + fit!(Ψ, dataset; + adjustment_method=adjustment_method, + verbosity=verbosity, + force=force ) - Ψ̂, Ψ̂ᵢ = mean(counterfactual_aggregate), mean(counterfactual_aggregateᵢ) - gradient_Y_Xᵢ, gradient_Y_X_fluct = gradients_Y_X(outcome_equation(Ψ).mach, fluctuation_mach) - IC = gradient_Y_X_fluct .+ gradient_W(counterfactual_aggregate, Ψ̂) - ICᵢ = gradient_Y_Xᵢ .+ gradient_W(counterfactual_aggregateᵢ, Ψ̂ᵢ) - return IC, Ψ̂, ICᵢ, Ψ̂ᵢ + # Get propensity score truncation threshold + ps_lowerbound = ps_lower_bound(Ψ, ps_lowerbound) + # Retrieve initial fit + Q = getQ(Ψ) + # Gradient and estimate + IC, Ψ̂ = gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) + verbosity >= 1 && @info "Done." + return OSEstimate(Ψ̂, Ψ̂ + mean(IC), IC), Q end diff --git a/src/fluctuation.jl b/src/fluctuation.jl new file mode 100644 index 00000000..4cada85e --- /dev/null +++ b/src/fluctuation.jl @@ -0,0 +1,86 @@ + +mutable struct ContinuousFluctuation <: MLJBase.Probabilistic + Ψ::CMCompositeEstimand + tol::Float64 + ps_lowerbound::Float64 + weighted::Bool +end + +mutable struct BinaryFluctuation <: MLJBase.Probabilistic + Ψ::CMCompositeEstimand + tol::Float64 + ps_lowerbound::Float64 + weighted::Bool +end + +FluctuationModel = Union{ContinuousFluctuation, BinaryFluctuation} + +function Fluctuation(Ψ, tol, ps_lowerbound, weighted) + outcome_scitype = scitype(getQ(Ψ).data[2]) + if outcome_scitype <: AbstractVector{<:MLJBase.Continuous} + return ContinuousFluctuation(Ψ, tol, ps_lowerbound, weighted) + elseif outcome_scitype <: AbstractVector{<:Finite} + return BinaryFluctuation(Ψ, tol, ps_lowerbound, weighted) + else + throw(ArgumentError("Cannot proceed with outcome with target_scitype: $outcome_scitype")) + end +end + +epsilon(Qstar) = fitted_params(fitted_params(Qstar).fitresult).coef[1] + +one_dimensional_path(model::ContinuousFluctuation) = LinearRegressor(fit_intercept=false, offsetcol = :offset) +one_dimensional_path(model::BinaryFluctuation) = LinearBinaryClassifier(fit_intercept=false, offsetcol = :offset) + +fluctuation_input(covariate::AbstractVector{T}, offset::AbstractVector{T}) where T = (covariate=covariate, offset=offset) + +""" + +The GLM models require inputs of the same type +""" +fluctuation_input(covariate::AbstractVector{T1}, offset::AbstractVector{T2}) where {T1, T2} = + (covariate=covariate, offset=convert(Vector{T1}, offset)) + +function clever_covariate_offset_and_weights(Ψ, Q, X; + ps_lowerbound=1e-8, + weighted_fluctuation=false + ) + offset = TMLE.compute_offset(MLJBase.predict(Q, X)) + covariate, weights = TMLE.clever_covariate_and_weights( + Ψ, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + Xfluct = fluctuation_input(covariate, offset) + return Xfluct, weights +end + +function MLJBase.fit(model::FluctuationModel, verbosity, X, y) + Ψ = model.Ψ + Q = getQ(Ψ) + clever_covariate_and_offset, weights = + clever_covariate_offset_and_weights( + Ψ, Q, X; + ps_lowerbound=model.ps_lowerbound, + weighted_fluctuation=model.weighted + ) + mach = machine( + one_dimensional_path(model), + clever_covariate_and_offset, + y, + weights, + ) + MLJBase.fit!(mach, verbosity=verbosity) + return mach, nothing, nothing +end + +function MLJBase.predict(model::FluctuationModel, one_dimensional_path, X) + Ψ = model.Ψ + Q = getQ(Ψ) + clever_covariate_and_offset, weights = + clever_covariate_offset_and_weights( + Ψ, Q, X; + ps_lowerbound=model.ps_lowerbound, + weighted_fluctuation=model.weighted + ) + return MLJBase.predict(one_dimensional_path, clever_covariate_and_offset) +end \ No newline at end of file diff --git a/src/gradient.jl b/src/gradient.jl new file mode 100644 index 00000000..c141e6e7 --- /dev/null +++ b/src/gradient.jl @@ -0,0 +1,57 @@ +""" + counterfactual_aggregate(Ψ::CMCompositeEstimand, Q::Machine) + +This is a counterfactual aggregate that depends on the parameter of interest. + +For the ATE with binary treatment, confounded by W it is equal to: + +``ctf_agg = Q(1, W) - Q(0, W)`` + +""" +function counterfactual_aggregate(Ψ::CMCompositeEstimand, Q::Machine) + X = getQ(Ψ).data[1] + Ttemplate = treatments(X, Ψ) + n = nrows(Ttemplate) + ctf_agg = zeros(n) + # Loop over Treatment settings + for (vals, sign) in indicator_fns(Ψ) + # Counterfactual dataset for a given treatment setting + T_ct = counterfactualTreatment(vals, Ttemplate) + X_ct = selectcols(merge(X, T_ct), keys(X)) + # Counterfactual predictions with F + ŷ = predict(Q, X_ct) + ctf_agg .+= sign .* expected_value(ŷ) + end + return ctf_agg +end + +""" + ∇W(ctf_agg, Ψ̂) + +∇_W = ctf_agg - Ψ̂ +""" +∇W(ctf_agg, Ψ̂) = ctf_agg .- Ψ̂ + + +""" + gradient_Y_X(cache) + +∇_YX(w, t, c) = covariate(w, t) ̇ (y - E[Y|w, t, c]) + +This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. +""" +function ∇YX(Ψ::CMCompositeEstimand, Q::Machine; ps_lowerbound=1e-8) + X, y = getQ(Ψ).data + H, weights = clever_covariate_and_weights(Ψ, X; ps_lowerbound=ps_lowerbound) + y = float(y) + gradient_Y_X_fluct = H .* weights .* (y .- expected_value(MLJBase.predict(Q))) + return gradient_Y_X_fluct +end + + +function gradient_and_estimate(Ψ::CMCompositeEstimand, Q::Machine; ps_lowerbound=1e-8) + ctf_agg = counterfactual_aggregate(Ψ, Q) + Ψ̂ = mean(ctf_agg) + IC = ∇YX(Ψ, Q; ps_lowerbound = ps_lowerbound) .+ ∇W(ctf_agg, Ψ̂) + return IC, Ψ̂ +end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index b73c9d2f..9cb1e3ec 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,6 +6,7 @@ selectcols(data, cols) = data |> TableOperations.select(cols...) |> Tables.colum data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = min(5 / (√(n)*log(n/5)), max_lb) + """ data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand) @@ -13,7 +14,7 @@ This startegy is from [this paper](https://academic.oup.com/aje/article/191/9/16 but the study does not show strictly better behaviour of the strategy so not a default for now. """ function data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand;max_lb=0.1) - n = nrows(outcome_equation(Ψ).mach.data[2]) + n = nrows(getQ(Ψ).data[2]) return data_adaptive_ps_lower_bound(n; max_lb=max_lb) end @@ -53,8 +54,6 @@ end ncases(value, Ψ::Estimand) = sum(value[i] == Ψ.treatment[i].case for i in eachindex(value)) -""" -""" function indicator_fns(Ψ::IATE) N = length(treatments(Ψ)) key_vals = Pair[] @@ -84,10 +83,6 @@ function indicator_values(indicators, T) return indic end -############################################################################### -## Estimand Validation -############################################################################### - NotIdentifiedError(reasons) = ArgumentError(string( "The estimand is not identified for the following reasons: \n\t- ", join(reasons, "\n\t- "))) @@ -141,14 +136,18 @@ function check_treatment_levels(Ψ::CMCompositeEstimand, dataset) end end -############################################################################### -## Offset & Covariate -############################################################################### - expected_value(ŷ::UnivariateFiniteVector{Multiclass{2}}) = pdf.(ŷ, levels(first(ŷ))[2]) expected_value(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = mean.(ŷ) expected_value(ŷ::AbstractVector{<:Real}) = ŷ +function counterfactualTreatment(vals, T) + Tnames = Tables.columnnames(T) + n = nrows(T) + NamedTuple{Tnames}( + [categorical(repeat([vals[i]], n), levels=levels(Tables.getcolumn(T, name)), ordered=isordered(Tables.getcolumn(T, name))) + for (i, name) in enumerate(Tnames)]) +end + function compute_offset(ŷ::UnivariateFiniteVector{Multiclass{2}}) μy = expected_value(ŷ) logit!(μy) @@ -158,12 +157,12 @@ compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = exp compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) compute_offset(Ψ::CMCompositeEstimand) = - compute_offset(MLJBase.predict(outcome_equation(Ψ).mach)) + compute_offset(MLJBase.predict(getQ(Ψ))) -function balancing_weights(scm::SCM, W, T; ps_lowerbound=1e-8) +function balancing_weights(Ψ::CMCompositeEstimand, W, T; ps_lowerbound=1e-8) density = ones(nrows(T)) for colname ∈ Tables.columnnames(T) - mach = scm[colname].mach + mach = Ψ.scm[colname].mach ŷ = MLJBase.predict(mach, W[colname]) density .*= pdf.(ŷ, Tables.getcolumn(T, colname)) end @@ -188,10 +187,12 @@ if `weighted_fluctuation = true`: where SpecialIndicator(t) is defined in `indicator_fns`. """ -function clever_covariate_and_weights(scm::SCM, W, T, indicator_fns; ps_lowerbound=1e-8, weighted_fluctuation=false) +function clever_covariate_and_weights(Ψ::CMCompositeEstimand, X; ps_lowerbound=1e-8, weighted_fluctuation=false) # Compute the indicator values - indic_vals = TMLE.indicator_values(indicator_fns, T) - weights = balancing_weights(scm, W, T, ps_lowerbound=ps_lowerbound) + T = treatments(X, Ψ) + W = confounders(X, Ψ) + indic_vals = indicator_values(indicator_fns(Ψ), T) + weights = balancing_weights(Ψ, W, T, ps_lowerbound=ps_lowerbound) if weighted_fluctuation return indic_vals, weights end @@ -199,28 +200,3 @@ function clever_covariate_and_weights(scm::SCM, W, T, indicator_fns; ps_lowerbou indic_vals .*= weights return indic_vals, ones(size(weights, 1)) end - -############################################################################### -## Fluctuation -############################################################################### - -fluctuation_input(covariate::AbstractVector{T}, offset::AbstractVector{T}) where T = (covariate=covariate, offset=offset) - -""" - -The GLM models require inputs of the same type -""" -fluctuation_input(covariate::AbstractVector{T1}, offset::AbstractVector{T2}) where {T1, T2} = - (covariate=covariate, offset=convert(Vector{T1}, offset)) - - -function counterfactualTreatment(vals, T) - Tnames = Tables.columnnames(T) - n = nrows(T) - NamedTuple{Tnames}( - [categorical(repeat([vals[i]], n), levels=levels(Tables.getcolumn(T, name)), ordered=isordered(Tables.getcolumn(T, name))) - for (i, name) in enumerate(Tnames)]) -end - - - diff --git a/test/composition.jl b/test/composition.jl index eb9a99e9..fa7575de 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -25,8 +25,8 @@ end @testset "Test cov" begin n = 10 X = rand(n, 2) - ER₁ = TMLE.TMLEstimate(1., X[:, 1]) - ER₂ = TMLE.TMLEstimate(0., X[:, 2]) + ER₁ = TMLE.TMLEstimate(0., 1., X[:, 1]) + ER₂ = TMLE.TMLEstimate(0., 0., X[:, 2]) Σ = cov(ER₁, ER₂) @test size(Σ) == (2, 2) @test Σ == cov(X) @@ -40,18 +40,20 @@ end outcome = :Y, treatment = (T=1,) ) - CM_result₁, _ = tmle!(CM₁, dataset, verbosity=0) + CM_tmle_result₁, _ = tmle!(CM₁, dataset, verbosity=0) + CM_ose_result₁, _ = ose!(CM₁, dataset, verbosity=0) # Conditional Mean T = 0 CM₀ = CM( scm, outcome = :Y, treatment = (T=0,) ) - CM_result₀, _ = tmle!(CM₀, dataset, verbosity=0) + CM_tmle_result₀, _ = tmle!(CM₀, dataset, verbosity=0) + CM_ose_result₀, _ = ose!(CM₀, dataset, verbosity=0) # Composition of TMLE - CM_result_composed_tmle = compose(-, tmle(CM_result₁), tmle(CM_result₀)); - CM_result_composed_ose = compose(-, ose(CM_result₁), ose(CM_result₀)); + CM_result_composed_tmle = compose(-, CM_tmle_result₁, CM_tmle_result₀); + CM_result_composed_ose = compose(-, CM_ose_result₁, CM_ose_result₀); # Via ATE ATE₁₀ = ATE( @@ -59,32 +61,33 @@ end outcome = :Y, treatment = (T=(case=1, control=0),), ) - ATE_result₁₀, _ = tmle!(ATE₁₀, dataset, verbosity=0) # Check composed TMLE - @test estimate(tmle(ATE_result₁₀)) ≈ estimate(CM_result_composed_tmle) atol = 1e-7 + ATE_tmle_result₁₀, _ = tmle!(ATE₁₀, dataset, verbosity=0) + @test estimate(ATE_tmle_result₁₀) ≈ estimate(CM_result_composed_tmle) atol = 1e-7 # T Test composed_confint = collect(confint(OneSampleTTest(CM_result_composed_tmle))) - tmle_confint = collect(confint(OneSampleTTest(tmle(ATE_result₁₀)))) + tmle_confint = collect(confint(OneSampleTTest(ATE_tmle_result₁₀))) @test tmle_confint ≈ composed_confint atol=1e-4 # Z Test composed_confint = collect(confint(OneSampleZTest(CM_result_composed_tmle))) - tmle_confint = collect(confint(OneSampleZTest(tmle(ATE_result₁₀)))) + tmle_confint = collect(confint(OneSampleZTest(ATE_tmle_result₁₀))) @test tmle_confint ≈ composed_confint atol=1e-4 # Variance - @test var(tmle(ATE_result₁₀)) ≈ var(CM_result_composed_tmle) atol = 1e-3 + @test var(ATE_tmle_result₁₀) ≈ var(CM_result_composed_tmle) atol = 1e-3 # Check composed OSE - @test estimate(ose(ATE_result₁₀)) ≈ estimate(CM_result_composed_ose) atol = 1e-7 + ATE_ose_result₁₀, _ = ose!(ATE₁₀, dataset, verbosity=0) + @test estimate(ATE_ose_result₁₀) ≈ estimate(CM_result_composed_ose) atol = 1e-7 # T Test composed_confint = collect(confint(OneSampleTTest(CM_result_composed_ose))) - ose_confint = collect(confint(OneSampleTTest(ose(ATE_result₁₀)))) + ose_confint = collect(confint(OneSampleTTest(ATE_ose_result₁₀))) @test ose_confint ≈ composed_confint atol=1e-4 # Z Test composed_confint = collect(confint(OneSampleZTest(CM_result_composed_tmle))) - ose_confint = collect(confint(OneSampleZTest(ose(ATE_result₁₀)))) + ose_confint = collect(confint(OneSampleZTest(ATE_ose_result₁₀))) @test ose_confint ≈ composed_confint atol=1e-4 # Variance - @test var(ose(ATE_result₁₀)) ≈ var(CM_result_composed_ose) atol = 1e-3 + @test var(ATE_ose_result₁₀) ≈ var(CM_result_composed_ose) atol = 1e-3 end @testset "Test compose multidimensional function" begin @@ -103,9 +106,9 @@ end ) CM_result₀, _ = tmle!(CM₀, dataset, verbosity=0) f(x, y) = [x^2 - y, x/y, 2x + 3y] - CM_result_composed = compose(f, CM_result₁.tmle, CM_result₀.tmle) + CM_result_composed = compose(f, CM_result₁, CM_result₀) - @test estimate(CM_result_composed) == f(estimate(CM_result₁.tmle), TMLE.estimate(CM_result₀.tmle)) + @test estimate(CM_result_composed) == f(estimate(CM_result₁), TMLE.estimate(CM_result₀)) @test size(var(CM_result_composed)) == (3, 3) end diff --git a/test/double_robustness_ate.jl b/test/double_robustness_ate.jl index eca65cb0..da1a8653 100644 --- a/test/double_robustness_ate.jl +++ b/test/double_robustness_ate.jl @@ -128,9 +128,9 @@ end test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test tmle_result.initial == 0 + @test initial(tmle_result) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LinearRegressor() @@ -157,9 +157,9 @@ end test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test tmle_result.initial == 0 + @test initial(tmle_result) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) @@ -187,9 +187,9 @@ end test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test tmle_result.initial == 0 + @test initial(tmle_result) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LinearRegressor() @@ -218,7 +218,7 @@ end test_coverage(tmle_result, ATE₁₁₋₀₁) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LinearRegressor() scm.T₁.model = ConstantClassifier() @@ -239,7 +239,7 @@ end test_coverage(tmle_result, ATE₁₁₋₀₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() diff --git a/test/double_robustness_iate.jl b/test/double_robustness_iate.jl index f08334d9..0dfa450e 100644 --- a/test/double_robustness_iate.jl +++ b/test/double_robustness_iate.jl @@ -191,9 +191,9 @@ end test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test tmle_result.initial == 0 + @test initial(tmle_result) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) @@ -204,9 +204,9 @@ end test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test tmle_result.initial ≈ -0.0 atol=1e-1 + @test initial(tmle_result) ≈ -0.0 atol=1e-1 end @testset "Test Double Robustness IATE on continuous_outcome_binary_treatment_pb" begin @@ -225,9 +225,9 @@ end test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test tmle_result.initial == 0 + @test initial(tmle_result) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> cont_interacter @@ -256,9 +256,9 @@ end tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test tmle_result.initial == 0 + @test initial(tmle_result) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> cat_interacter @@ -268,9 +268,9 @@ end tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test tmle_result.initial ≈ -0.02 atol=1e-2 + @test initial(tmle_result) ≈ -0.02 atol=1e-2 end diff --git a/test/estimands.jl b/test/estimands.jl index a87983fd..6caf1f2e 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -247,9 +247,6 @@ end @test isconcretetype(CM) @test isconcretetype(TMLE.TMLEstimate{Float64}) @test isconcretetype(TMLE.OSEstimate{Float64}) - @test isconcretetype(TMLE.TMLEResult{ATE, Float64}) - @test isconcretetype(TMLE.TMLEResult{IATE, Float64}) - @test isconcretetype(TMLE.TMLEResult{CM, Float64}) end end diff --git a/test/estimation.jl b/test/estimation.jl index 339367e4..365b60a9 100644 --- a/test/estimation.jl +++ b/test/estimation.jl @@ -1,4 +1,4 @@ -module TestWarmRestart +module TestEstimation using Test using Tables @@ -77,7 +77,6 @@ table_types = (Tables.columntable, DataFrame) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset; verbosity=1); # The TMLE covers the ground truth but the initial estimate does not Ψ₀ = -1 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -96,7 +95,6 @@ table_types = (Tables.columntable, DataFrame) ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 1 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -117,7 +115,6 @@ table_types = (Tables.columntable, DataFrame) ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 0.5 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -137,7 +134,6 @@ table_types = (Tables.columntable, DataFrame) ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 10 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -156,7 +152,6 @@ table_types = (Tables.columntable, DataFrame) ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = -0.5 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -174,7 +169,6 @@ table_types = (Tables.columntable, DataFrame) ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = -1.5 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -199,7 +193,6 @@ end ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 3 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -217,7 +210,6 @@ end ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 2.5 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -236,7 +228,6 @@ end ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = 1.5 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -250,7 +241,6 @@ end (:info, "Done.") ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -274,7 +264,6 @@ end ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); Ψ₀ = -1 - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -291,7 +280,6 @@ end (:info, "Done.") ) tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - @test tmle_result.estimand == Ψ test_coverage(tmle_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) diff --git a/test/fluctuation.jl b/test/fluctuation.jl new file mode 100644 index 00000000..f856bd61 --- /dev/null +++ b/test/fluctuation.jl @@ -0,0 +1,118 @@ +module TestFluctuation + +using Test +using TMLE +using MLJModels +using MLJBase + +function test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) + Q⁰ = TMLE.getQ(Ψ) + X, y = Q⁰.data + Q = machine( + TMLE.Fluctuation(Ψ, 0.1, ps_lowerbound, weighted_fluctuation), + X, + y + ) + fit!(Q, verbosity=0) + Xfluct, weights = TMLE.clever_covariate_offset_and_weights(Ψ, Q⁰, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + @test Q.fitresult.data[1] == Xfluct + @test Q.fitresult.data[3] == weights +end + + +@testset "Test Fluctuation with 1 Treatments" begin + ### In all cases, models are "constant" models + ## First case: 1 Treamtent variable + Ψ = ATE( + outcome=:Y, + confounders=[:W₁, :W₂, :W₃], + treatment=(T=(case="a", control="b"),), + treatment_model = ConstantClassifier(), + outcome_model = ConstantRegressor() + ) + dataset = ( + T = categorical(["a", "b", "c", "a", "a", "b", "a"]), + Y = [1., 2., 3, 4, 5, 6, 7], + W₁ = rand(7), + W₂ = rand(7), + W₃ = rand(7) + ) + fit!(Ψ, dataset, verbosity=0) + + ps_lowerbound = 1e-8 + weighted_fluctuation = true + test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) + + weighted_fluctuation = false + test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) +end + +@testset "Test Fluctuation with 2 Treatments" begin + Ψ = IATE( + outcome =:Y, + treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), + confounders = [:W₁, :W₂, :W₃], + treatment_model = ConstantClassifier(), + outcome_model = ConstantClassifier() + ) + dataset = ( + T₁ = categorical([1, 0, 0, 1, 1, 1, 0]), + T₂ = categorical([1, 1, 1, 1, 1, 0, 0]), + Y = categorical([1, 1, 1, 1, 0, 0, 0]), + W₁ = rand(7), + W₂ = rand(7), + W₃ = rand(7) + ) + ps_lowerbound = 1e-8 + weighted_fluctuation = false + + fit!(Ψ, dataset, verbosity=0) + test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) +end + +@testset "Test Fluctuation with 3 Treatments" begin + ## Third case: 3 Treatment variables + Ψ = IATE( + outcome =:Y, + treatment=(T₁=(case="a", control="b"), + T₂=(case=1, control=2), + T₃=(case=true, control=false)), + confounders=:W, + treatment_model = ConstantClassifier(), + outcome_model = DeterministicConstantRegressor() + ) + dataset = ( + T₁ = categorical(["a", "a", "b", "b", "c", "b", "b"]), + T₂ = categorical([3, 2, 1, 1, 2, 2, 2], ordered=true), + T₃ = categorical([true, false, true, false, false, false, false], ordered=true), + Y = [1., 2., 3, 4, 5, 6, 7], + W = rand(7), + ) + ps_lowerbound = 1e-8 + weighted_fluctuation = false + + fit!(Ψ, dataset, verbosity=0) + test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) +end + +@testset "Test fluctuation_input" begin + X = TMLE.fluctuation_input([1., 2.], [1., 2]) + @test X.covariate isa Vector{Float64} + @test X.offset isa Vector{Float64} + + X = TMLE.fluctuation_input([1., 2.], [1.f0, 2.f0]) + @test X.covariate isa Vector{Float64} + @test X.offset isa Vector{Float64} + + X = TMLE.fluctuation_input([1.f0, 2.f0], [1., 2.]) + @test X.covariate isa Vector{Float32} + @test X.offset isa Vector{Float32} + +end + +end + +true \ No newline at end of file diff --git a/test/gradient.jl b/test/gradient.jl new file mode 100644 index 00000000..b5747578 --- /dev/null +++ b/test/gradient.jl @@ -0,0 +1,67 @@ +module TestGradient + +using TMLE +using Test +using MLJBase +using StableRNGs +using LogExpFunctions +using Distributions +using MLJLinearModels +using MLJModels + +μY(T, W) = 1 .+ 2T .- W.*T + +function one_treatment_dataset(;n=100) + rng = StableRNG(123) + W = rand(rng, n) + μT = logistic.(1 .- 3W) + T = rand(rng, n) .< μT + Y = μY(T, W) .+ rand(rng, Normal(0, 0.1), n) + return ( + T = categorical(T, ordered=true), + Y = Y, + W = W + ) +end + +@testset "Test gradient_and_estimate" begin + ps_lowerbound = 1e-8 + Ψ = ATE( + outcome = :Y, + treatment = (T=(case=1, control=0),), + confounders = [:W], + treatment_model = LogisticClassifier(), + outcome_model = with_encoder(InteractionTransformer(order=2) |> LinearRegressor()) + ) + dataset = one_treatment_dataset(;n=100) + fit!(Ψ, dataset, verbosity=0) + + # Retrieve outcome model + Q = TMLE.getQ(Ψ) + linear_model = fitted_params(Q).deterministic_pipeline.linear_regressor + intercept = linear_model.intercept + coefs = Dict(linear_model.coefs) + # Counterfactual aggregate + ctf_agg = TMLE.counterfactual_aggregate(Ψ, Q) + expected_ctf_agg = (intercept .+ coefs[:T] .+ dataset.W.*coefs[:W] .+ dataset.W.*coefs[:T_W]) .- (intercept .+ dataset.W.*coefs[:W]) + @test ctf_agg ≈ expected_ctf_agg atol=1e-10 + # Retrieve propensity score model + G = Ψ.scm[:T].mach + # Gradient Y|X + H = 1 ./ pdf.(predict(G), dataset.T) .* [t == 1 ? 1. : -1. for t in dataset.T] + expected_∇YX = H .* (dataset.Y .- predict(Q)) + ∇YX = TMLE.∇YX(Ψ, Q; ps_lowerbound=ps_lowerbound) + @test expected_∇YX == ∇YX + # Gradient W + expectedΨ̂ = mean(ctf_agg) + ∇W = TMLE.∇W(ctf_agg, expectedΨ̂) + @test ∇W == ctf_agg .- expectedΨ̂ + # gradient_and_estimate + IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) + @test expectedΨ̂ == Ψ̂ + @test IC == ∇YX .+ ∇W +end + +end + +true \ No newline at end of file diff --git a/test/helper_fns.jl b/test/helper_fns.jl index e622cefc..a362c0b3 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -21,7 +21,7 @@ It seems that sometimes this is not entirely true in practice, so the test actua increase risk more than tol """ function test_fluct_decreases_risk(Ψ::TMLE.CMCompositeEstimand, fluctuation_mach; atol=1e-6) - outcome_mach = TMLE.outcome_equation(Ψ).mach + outcome_mach = TMLE.getQ(Ψ) y = outcome_mach.data[2] initial_risk = risk(MLJBase.predict(outcome_mach), y) fluct_risk = risk(MLJBase.predict(fluctuation_mach), y) @@ -30,26 +30,26 @@ end """ - test_coverage(tmle_result::TMLE.TMLEResult, Ψ₀) + test_coverage(tmle_result::TMLE.EICEstimate, Ψ₀) Both the TMLE and OneStep estimators are suppose to be asymptotically efficient and cover the truth at the given confidence level: here 0.05 """ -function test_coverage(result::TMLE.TMLEResult, Ψ₀) +function test_coverage(result::TMLE.EICEstimate, Ψ₀) # TMLE - lb, ub = confint(OneSampleTTest(result.tmle)) + lb, ub = confint(OneSampleTTest(result)) @test lb ≤ Ψ₀ ≤ ub # OneStep - lb, ub = confint(OneSampleTTest(result.onestep)) + lb, ub = confint(OneSampleTTest(result)) @test lb ≤ Ψ₀ ≤ ub end """ - test_mean_inf_curve_almost_zero(tmle_result::TMLE.TMLEResult; atol=1e-10) + test_mean_inf_curve_almost_zero(tmle_result::TMLE.EICEstimate; atol=1e-10) The TMLE is supposed to solve the EIC score equation. """ -test_mean_inf_curve_almost_zero(tmle_result::TMLE.TMLEResult; atol=1e-10) = @test mean(tmle_result.tmle.IC) ≈ 0.0 atol=atol +test_mean_inf_curve_almost_zero(tmle_result::TMLE.EICEstimate; atol=1e-10) = @test mean(tmle_result.IC) ≈ 0.0 atol=atol """ test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.TMLEResult) @@ -57,4 +57,4 @@ test_mean_inf_curve_almost_zero(tmle_result::TMLE.TMLEResult; atol=1e-10) = @tes This cqnnot be guaranteed in general since a well specified maximum likelihood estimator also solves the score equation. """ -test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.TMLEResult) = @test abs(mean(tmle_result.tmle.IC)) < abs(mean(tmle_result.onestep.IC)) \ No newline at end of file +test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.EICEstimate) = @test abs(mean(tmle_result.IC)) < abs(mean(tmle_result.IC)) \ No newline at end of file diff --git a/test/non_regression_test.jl b/test/non_regression_test.jl index 99d5f201..18ca4188 100644 --- a/test/non_regression_test.jl +++ b/test/non_regression_test.jl @@ -16,28 +16,23 @@ using MLJGLMInterface for col in confounders dataset[!, col] = float(dataset[!, col]) end - scm = StaticConfoundedModel( - :haz01, :parity01, confounders, + + Ψ = ATE( + outcome=:haz01, + treatment=(parity01=(case=1, control=0),), + confounders=confounders, outcome_model = TreatmentTransformer() |> LinearBinaryClassifier(), treatment_model = LinearBinaryClassifier() ) - Ψ = ATE(scm, outcome=:haz01, treatment=(parity01=(case=1, control=0),)) - result, fluctuation_mach = tmle!(Ψ, dataset, ps_lowerbound=0.025, verbosity=0) + tmle_result, fluctuation_mach = tmle!(Ψ, dataset, ps_lowerbound=0.025, verbosity=0) # TMLE - tmle_result = tmle(result) @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 + @test initial(tmle_result) ≈ -0.150078 atol = 1e-6 l, u = confint(OneSampleTTest(tmle_result)) @test l ≈ -0.279246 atol = 1e-6 @test u ≈ -0.091821 atol = 1e-6 @test OneSampleZTest(tmle_result) isa OneSampleZTest - # OSE - ose_result = ose(result) - @test estimate(ose_result) isa Float64 - @test OneSampleTTest(ose_result) isa OneSampleTTest - @test OneSampleZTest(ose_result) isa OneSampleZTest - # Naive - @test initial(result) isa Float64 end diff --git a/test/runtests.jl b/test/runtests.jl index 58a0cc54..94752c9c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,8 @@ using Test @test include("scm.jl") @test include("non_regression_test.jl") @test include("utils.jl") + @test include("gradient.jl") + @test include("fluctuation.jl") @test include("double_robustness_ate.jl") @test include("double_robustness_iate.jl") @test include("3points_interactions.jl") diff --git a/test/utils.jl b/test/utils.jl index a41c1635..266b3b52 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,14 +1,10 @@ module TestUtils using Test -using Tables -using TableOperations using TMLE using MLJBase using StableRNGs -using Distributions using CategoricalArrays -using MLJGLMInterface: LinearBinaryClassifier using MLJLinearModels using MLJModels @@ -154,18 +150,37 @@ end @test !isordered(cfT.T₂) end -@testset "Test Targeting sequence with 1 Treatment" begin +@testset "Test ps_lower_bound" begin + dataset = ( + T = categorical([1, 0, 1, 1, 0, 1, 1]), + Y = [1., 2., 3, 4, 5, 6, 7], + W = rand(7), + ) + Ψ = ATE( + outcome=:Y, + treatment=(T=(case=1, control=0),), + confounders=:W, + ) + max_lb = 0.1 + fit!(Ψ, dataset, verbosity=0) + # Nothing results in data adaptive lower bound no lower than max_lb + @test TMLE.ps_lower_bound(Ψ, nothing) == max_lb + @test TMLE.data_adaptive_ps_lower_bound(Ψ) == max_lb + @test TMLE.data_adaptive_ps_lower_bound(1000) == 0.02984228238321508 + # Otherwise use the provided threhsold provided it's lower than max_lb + @test TMLE.ps_lower_bound(Ψ, 1e-8) == 1.0e-8 + @test TMLE.ps_lower_bound(Ψ, 1) == 0.1 +end + +@testset "Test compute_offset, clever_covariate_and_weights: 1 treatment" begin ### In all cases, models are "constant" models ## First case: 1 Treamtent variable - scm = StaticConfoundedModel( - :Y, :T, [:W₁, :W₂, :W₃], - treatment_model = ConstantClassifier(), - outcome_model = ConstantRegressor() - ) Ψ = ATE( - scm, outcome=:Y, treatment=(T=(case="a", control="b"),), + confounders=[:W₁, :W₂, :W₃], + treatment_model = ConstantClassifier(), + outcome_model = ConstantRegressor() ) dataset = ( T = categorical(["a", "b", "c", "a", "a", "b", "a"]), @@ -176,46 +191,35 @@ end ) fit!(Ψ, dataset, verbosity=0) - @test TMLE.ps_lower_bound(Ψ, nothing) == 0.1 - @test TMLE.data_adaptive_ps_lower_bound(Ψ) == 0.1 - @test TMLE.data_adaptive_ps_lower_bound(1000) == 0.02984228238321508 - @test TMLE.ps_lower_bound(Ψ, 1e-8) == 1.0e-8 - @test TMLE.ps_lower_bound(Ψ, 1) == 0.1 - - indicator_fns = TMLE.indicator_fns(Ψ) - T = TMLE.treatments(dataset, Ψ) - @test T == (T=dataset.T,) - W = TMLE.confounders(dataset, Ψ) - @test W == (T = (W₁=dataset.W₁, W₂=dataset.W₂, W₃=dataset.W₃),) - - @test TMLE.compute_offset(Ψ) == repeat([mean(dataset.Y)], 7) + offset = TMLE.compute_offset(Ψ) + @test offset == repeat([mean(dataset.Y)], 7) weighted_fluctuation = true - bw = TMLE.balancing_weights(scm, W, T) - cov, w = TMLE.clever_covariate_and_weights(scm, W, T, indicator_fns, weighted_fluctuation=weighted_fluctuation) + ps_lowerbound = 1e-8 + X, y = TMLE.getQ(Ψ).data + cov, w = TMLE.clever_covariate_and_weights(Ψ, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] - @test w == bw == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] - - mach = TMLE.tmle_step(Ψ, weighted_fluctuation=weighted_fluctuation) - @test fitted_params(mach).coef[1] isa Float64 + @test w == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] - cov, w = TMLE.clever_covariate_and_weights(scm, W, T, indicator_fns) + weighted_fluctuation = false + cov, w = TMLE.clever_covariate_and_weights(Ψ, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] - - unweighted_mach = TMLE.tmle_step(Ψ, weighted_fluctuation=false) - @test fitted_params(mach).coef[1] isa Float64 + @test w == ones(7) end -@testset "Test Targeting sequence with 2 Treatments" begin - scm = StaticConfoundedModel( - [:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃], - treatment_model = ConstantClassifier(), - outcome_model = ConstantClassifier() - ) +@testset "Test compute_offset, clever_covariate_and_weights: 2 treatments" begin Ψ = IATE( - scm, - outcome =:Y, + outcome = :Y, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), + confounders=[:W₁, :W₂, :W₃], + treatment_model = ConstantClassifier(), + outcome_model = ConstantClassifier() ) dataset = ( T₁ = categorical([1, 0, 0, 1, 1, 1, 0]), @@ -225,39 +229,33 @@ end W₂ = rand(7), W₃ = rand(7) ) + ps_lowerbound = 1e-8 + weighted_fluctuation = false + fit!(Ψ, dataset, verbosity=0) - indicator_fns = TMLE.indicator_fns(Ψ) - T = TMLE.treatments(dataset, Ψ) - @test T == (T₁ = dataset.T₁, T₂ = dataset.T₂) - W = TMLE.confounders(dataset, Ψ) - @test W == ( - T₁ = (W₁ = dataset.W₁, W₂ = dataset.W₂, W₃ = dataset.W₃), - T₂ = (W₁ = dataset.W₁, W₂ = dataset.W₂, W₃ = dataset.W₃) - ) + # Because the outcome is binary, the offset is the logit - @test TMLE.compute_offset(Ψ) == repeat([0.28768207245178085], 7) - - cov, w = TMLE.clever_covariate_and_weights(scm, W, T, indicator_fns) + offset = TMLE.compute_offset(Ψ) + @test offset == repeat([0.28768207245178085], 7) + X, y = TMLE.getQ(Ψ).data + cov, w = TMLE.clever_covariate_and_weights(Ψ, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - - mach = TMLE.tmle_step(Ψ) - @test fitted_params(mach).coef[1] isa Float64 end -@testset "Test Targeting sequence with 3 Treatments" begin +@testset "Test compute_offset, clever_covariate_and_weights: 3 treatments" begin ## Third case: 3 Treatment variables - scm = StaticConfoundedModel( - [:Y], [:T₁, :T₂, :T₃], [:W], - treatment_model = ConstantClassifier(), - outcome_model = DeterministicConstantRegressor() - ) Ψ = IATE( - scm, outcome =:Y, treatment=(T₁=(case="a", control="b"), T₂=(case=1, control=2), T₃=(case=true, control=false)), + confounders=[:W], + treatment_model = ConstantClassifier(), + outcome_model = DeterministicConstantRegressor() ) dataset = ( T₁ = categorical(["a", "a", "b", "b", "c", "b", "b"]), @@ -266,42 +264,20 @@ end Y = [1., 2., 3, 4, 5, 6, 7], W = rand(7), ) + ps_lowerbound = 1e-8 + weighted_fluctuation = false fit!(Ψ, dataset, verbosity=0) - indicator_fns = TMLE.indicator_fns(Ψ) - T = TMLE.treatments(dataset, Ψ) - @test T == (T₁ = dataset.T₁, T₂ = dataset.T₂, T₃ = dataset.T₃) - W = TMLE.confounders(dataset, Ψ) - @test W == ( - T₁ = (W = dataset.W,), - T₂ = (W = dataset.W,), - T₃ = (W = dataset.W,) - ) - indicator_fns = TMLE.indicator_fns(Ψ) - - @test TMLE.compute_offset(Ψ) == repeat([4.0], 7) - cov, w = TMLE.clever_covariate_and_weights(scm, W, T, indicator_fns) + offset = TMLE.compute_offset(Ψ) + @test offset == repeat([4.0], 7) + X, y = TMLE.getQ(Ψ).data + cov, w = TMLE.clever_covariate_and_weights(Ψ, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - - mach = TMLE.tmle_step(Ψ) - @test fitted_params(mach).coef[1] isa Float64 -end - -@testset "Test fluctuation_input" begin - X = TMLE.fluctuation_input([1., 2.], [1., 2]) - @test X.covariate isa Vector{Float64} - @test X.offset isa Vector{Float64} - - X = TMLE.fluctuation_input([1., 2.], [1.f0, 2.f0]) - @test X.covariate isa Vector{Float64} - @test X.offset isa Vector{Float64} - - X = TMLE.fluctuation_input([1.f0, 2.f0], [1., 2.]) - @test X.covariate isa Vector{Float32} - @test X.offset isa Vector{Float32} - end end; From 4b8a8b34658a31f1e91af285bd5327b41151a64a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 11:35:03 +0100 Subject: [PATCH 050/151] some renames --- src/estimands.jl | 4 +++- src/estimation.jl | 4 ++-- src/fluctuation.jl | 6 +++--- src/gradient.jl | 4 ++-- src/utils.jl | 4 ++-- test/fluctuation.jl | 2 +- test/gradient.jl | 2 +- test/helper_fns.jl | 2 +- test/runtests.jl | 2 +- test/utils.jl | 6 +++--- 10 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/estimands.jl b/src/estimands.jl index b88d4efe..ecc882ca 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -188,7 +188,9 @@ end outcome_equation(Ψ::CMCompositeEstimand) = Ψ.scm[outcome(Ψ)] -getQ(Ψ::CMCompositeEstimand) = outcome_equation(Ψ).mach +get_outcome_model(Ψ::CMCompositeEstimand) = outcome_equation(Ψ).mach + +get_outcome_datas(Ψ::CMCompositeEstimand) = get_outcome_model(Ψ).data treatments(Ψ::CMCompositeEstimand) = collect(keys(Ψ.treatment)) treatments(dataset, Ψ::CMCompositeEstimand) = selectcols(dataset, treatments(Ψ)) diff --git a/src/estimation.jl b/src/estimation.jl index 4713a9d0..ea545a98 100644 --- a/src/estimation.jl +++ b/src/estimation.jl @@ -50,7 +50,7 @@ function tmle!(Ψ::CMCompositeEstimand, dataset; ps_lowerbound = ps_lower_bound(Ψ, ps_lowerbound) # Fit Fluctuation verbosity >= 1 && @info "Performing TMLE..." - Q⁰ = getQ(Ψ) + Q⁰ = get_outcome_model(Ψ) X, y = Q⁰.data Q = machine( Fluctuation(Ψ, 0.1, ps_lowerbound, weighted_fluctuation), @@ -82,7 +82,7 @@ function ose!(Ψ::CMCompositeEstimand, dataset; # Get propensity score truncation threshold ps_lowerbound = ps_lower_bound(Ψ, ps_lowerbound) # Retrieve initial fit - Q = getQ(Ψ) + Q = get_outcome_model(Ψ) # Gradient and estimate IC, Ψ̂ = gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." diff --git a/src/fluctuation.jl b/src/fluctuation.jl index 4cada85e..b33d965f 100644 --- a/src/fluctuation.jl +++ b/src/fluctuation.jl @@ -16,7 +16,7 @@ end FluctuationModel = Union{ContinuousFluctuation, BinaryFluctuation} function Fluctuation(Ψ, tol, ps_lowerbound, weighted) - outcome_scitype = scitype(getQ(Ψ).data[2]) + outcome_scitype = scitype(get_outcome_datas(Ψ)[2]) if outcome_scitype <: AbstractVector{<:MLJBase.Continuous} return ContinuousFluctuation(Ψ, tol, ps_lowerbound, weighted) elseif outcome_scitype <: AbstractVector{<:Finite} @@ -56,7 +56,7 @@ end function MLJBase.fit(model::FluctuationModel, verbosity, X, y) Ψ = model.Ψ - Q = getQ(Ψ) + Q = get_outcome_model(Ψ) clever_covariate_and_offset, weights = clever_covariate_offset_and_weights( Ψ, Q, X; @@ -75,7 +75,7 @@ end function MLJBase.predict(model::FluctuationModel, one_dimensional_path, X) Ψ = model.Ψ - Q = getQ(Ψ) + Q = get_outcome_model(Ψ) clever_covariate_and_offset, weights = clever_covariate_offset_and_weights( Ψ, Q, X; diff --git a/src/gradient.jl b/src/gradient.jl index c141e6e7..d3615e66 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -9,7 +9,7 @@ For the ATE with binary treatment, confounded by W it is equal to: """ function counterfactual_aggregate(Ψ::CMCompositeEstimand, Q::Machine) - X = getQ(Ψ).data[1] + X = get_outcome_datas(Ψ)[1] Ttemplate = treatments(X, Ψ) n = nrows(Ttemplate) ctf_agg = zeros(n) @@ -41,7 +41,7 @@ end This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. """ function ∇YX(Ψ::CMCompositeEstimand, Q::Machine; ps_lowerbound=1e-8) - X, y = getQ(Ψ).data + X, y = get_outcome_datas(Ψ) H, weights = clever_covariate_and_weights(Ψ, X; ps_lowerbound=ps_lowerbound) y = float(y) gradient_Y_X_fluct = H .* weights .* (y .- expected_value(MLJBase.predict(Q))) diff --git a/src/utils.jl b/src/utils.jl index 9cb1e3ec..fe649c15 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -14,7 +14,7 @@ This startegy is from [this paper](https://academic.oup.com/aje/article/191/9/16 but the study does not show strictly better behaviour of the strategy so not a default for now. """ function data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand;max_lb=0.1) - n = nrows(getQ(Ψ).data[2]) + n = nrows(get_outcome_datas(Ψ)[2]) return data_adaptive_ps_lower_bound(n; max_lb=max_lb) end @@ -157,7 +157,7 @@ compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = exp compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) compute_offset(Ψ::CMCompositeEstimand) = - compute_offset(MLJBase.predict(getQ(Ψ))) + compute_offset(MLJBase.predict(get_outcome_model(Ψ))) function balancing_weights(Ψ::CMCompositeEstimand, W, T; ps_lowerbound=1e-8) density = ones(nrows(T)) diff --git a/test/fluctuation.jl b/test/fluctuation.jl index f856bd61..13fe023a 100644 --- a/test/fluctuation.jl +++ b/test/fluctuation.jl @@ -6,7 +6,7 @@ using MLJModels using MLJBase function test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) - Q⁰ = TMLE.getQ(Ψ) + Q⁰ = TMLE.get_outcome_model(Ψ) X, y = Q⁰.data Q = machine( TMLE.Fluctuation(Ψ, 0.1, ps_lowerbound, weighted_fluctuation), diff --git a/test/gradient.jl b/test/gradient.jl index b5747578..378409a8 100644 --- a/test/gradient.jl +++ b/test/gradient.jl @@ -37,7 +37,7 @@ end fit!(Ψ, dataset, verbosity=0) # Retrieve outcome model - Q = TMLE.getQ(Ψ) + Q = TMLE.get_outcome_model(Ψ) linear_model = fitted_params(Q).deterministic_pipeline.linear_regressor intercept = linear_model.intercept coefs = Dict(linear_model.coefs) diff --git a/test/helper_fns.jl b/test/helper_fns.jl index a362c0b3..24a6b3af 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -21,7 +21,7 @@ It seems that sometimes this is not entirely true in practice, so the test actua increase risk more than tol """ function test_fluct_decreases_risk(Ψ::TMLE.CMCompositeEstimand, fluctuation_mach; atol=1e-6) - outcome_mach = TMLE.getQ(Ψ) + outcome_mach = TMLE.get_outcome_model(Ψ) y = outcome_mach.data[2] initial_risk = risk(MLJBase.predict(outcome_mach), y) fluct_risk = risk(MLJBase.predict(fluctuation_mach), y) diff --git a/test/runtests.jl b/test/runtests.jl index 94752c9c..ec693547 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Test - +Ψ @time begin @test include("scm.jl") @test include("non_regression_test.jl") diff --git a/test/utils.jl b/test/utils.jl index 266b3b52..7e14c110 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -195,7 +195,7 @@ end @test offset == repeat([mean(dataset.Y)], 7) weighted_fluctuation = true ps_lowerbound = 1e-8 - X, y = TMLE.getQ(Ψ).data + X, y = TMLE.get_outcome_datas(Ψ) cov, w = TMLE.clever_covariate_and_weights(Ψ, X; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation @@ -237,7 +237,7 @@ end # Because the outcome is binary, the offset is the logit offset = TMLE.compute_offset(Ψ) @test offset == repeat([0.28768207245178085], 7) - X, y = TMLE.getQ(Ψ).data + X, y = TMLE.get_outcome_datas(Ψ) cov, w = TMLE.clever_covariate_and_weights(Ψ, X; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation @@ -271,7 +271,7 @@ end offset = TMLE.compute_offset(Ψ) @test offset == repeat([4.0], 7) - X, y = TMLE.getQ(Ψ).data + X, y = TMLE.get_outcome_datas(Ψ) cov, w = TMLE.clever_covariate_and_weights(Ψ, X; ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation From 884bb9fd2b58f6551bb52ed0176d980df46fddc5 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 12:28:45 +0100 Subject: [PATCH 051/151] restructure project --- src/TMLE.jl | 15 +- src/adjustment.jl | 2 +- src/counterfactual_mean_based/estimands.jl | 219 ++++++++++++++++++ .../estimation.jl | 0 .../fluctuation.jl | 0 .../gradient.jl | 0 .../offset_and_covariate.jl | 77 ++++++ src/estimands.jl | 219 ++++-------------- src/utils.jl | 154 +----------- .../3points_interactions.jl | 2 +- .../double_robustness_ate.jl | 2 +- .../double_robustness_iate.jl | 2 +- .../estimands.jl | 77 +++++- .../estimation.jl | 2 +- .../fluctuation.jl | 0 .../gradient.jl | 0 .../non_regression_test.jl | 0 .../offset_and_covariate.jl | 141 +++++++++++ test/runtests.jl | 18 +- test/utils.jl | 202 ---------------- 20 files changed, 575 insertions(+), 557 deletions(-) create mode 100644 src/counterfactual_mean_based/estimands.jl rename src/{ => counterfactual_mean_based}/estimation.jl (100%) rename src/{ => counterfactual_mean_based}/fluctuation.jl (100%) rename src/{ => counterfactual_mean_based}/gradient.jl (100%) create mode 100644 src/counterfactual_mean_based/offset_and_covariate.jl rename test/{ => counterfactual_mean_based}/3points_interactions.jl (96%) rename test/{ => counterfactual_mean_based}/double_robustness_ate.jl (99%) rename test/{ => counterfactual_mean_based}/double_robustness_iate.jl (99%) rename test/{ => counterfactual_mean_based}/estimands.jl (75%) rename test/{ => counterfactual_mean_based}/estimation.jl (99%) rename test/{ => counterfactual_mean_based}/fluctuation.jl (100%) rename test/{ => counterfactual_mean_based}/gradient.jl (100%) rename test/{ => counterfactual_mean_based}/non_regression_test.jl (100%) create mode 100644 test/counterfactual_mean_based/offset_and_covariate.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index 1bbea29c..37418d55 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -41,15 +41,20 @@ export BackdoorAdjustment # INCLUDES # ############################################################################# +include("utils.jl") include("scm.jl") include("estimands.jl") -include("fluctuation.jl") -include("adjustment.jl") -include("gradient.jl") -include("utils.jl") -include("estimation.jl") include("estimate.jl") include("treatment_transformer.jl") +include("adjustment.jl") + +include("counterfactual_mean_based/estimands.jl") +include("counterfactual_mean_based/offset_and_covariate.jl") +include("counterfactual_mean_based/fluctuation.jl") +include("counterfactual_mean_based/gradient.jl") +include("counterfactual_mean_based/estimation.jl") + + # ############################################################################# diff --git a/src/adjustment.jl b/src/adjustment.jl index 6f7ba79c..8aa4cea5 100644 --- a/src/adjustment.jl +++ b/src/adjustment.jl @@ -31,7 +31,7 @@ outcome_input_variables(treatments, confounding_variables, extra_covariates) = u extra_covariates )) -function get_models_input_variables(adjustment_method::BackdoorAdjustment, Ψ::CMCompositeEstimand) +function get_models_input_variables(adjustment_method::BackdoorAdjustment, Ψ::Estimand) models_inputs = [] for treatment in treatments(Ψ) push!(models_inputs, parents(Ψ.scm, treatment)) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl new file mode 100644 index 00000000..c28855f0 --- /dev/null +++ b/src/counterfactual_mean_based/estimands.jl @@ -0,0 +1,219 @@ +##################################################################### +### Conditional Mean ### +##################################################################### + +""" +# Conditional Mean / CM + +## Definition + +``CM(Y, T=t) = E[Y|do(T=t)]`` + +## Constructors + +- CM(scm::SCM; outcome::Symbol, treatment::NamedTuple) +- CM( + outcome::Symbol, + treatment::NamedTuple, + confounders::Union{Symbol, AbstractVector{Symbol}}, + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LinearBinaryClassifier() +) +## Example + +```julia +Ψ = CM(scm, outcome=:Y, treatment=(T=1,)) + +``` +""" +struct ConditionalMean <: Estimand + scm::StructuralCausalModel + outcome::Symbol + treatment::NamedTuple + function ConditionalMean(scm, outcome, treatment) + check_parameter_against_scm(scm, outcome, treatment) + return new(scm, outcome, treatment) + end +end + +const CM = ConditionalMean + +indicator_fns(Ψ::CM) = Dict(values(Ψ.treatment) => 1.) + + +##################################################################### +### Average Treatment Effect ### +##################################################################### +""" +# Average Treatment Effect / ATE + +## Definition + +``ATE(Y, T, case, control) = E[Y|do(T=case)] - E[Y|do(T=control)`` + +## Constructors + +- ATE(scm::SCM; outcome::Symbol, treatment::NamedTuple) +- ATE( + outcome::Symbol, + treatment::NamedTuple, + confounders::Union{Symbol, AbstractVector{Symbol}}, + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LinearBinaryClassifier() +) + +## Example + +```julia +Ψ = ATE(scm, outcome=:Y, treatment=(T=(case=1,control=0),) +``` +""" +struct AverageTreatmentEffect <: Estimand + scm::StructuralCausalModel + outcome::Symbol + treatment::NamedTuple + function AverageTreatmentEffect(scm, outcome, treatment) + check_parameter_against_scm(scm, outcome, treatment) + return new(scm, outcome, treatment) + end +end + +const ATE = AverageTreatmentEffect + +function indicator_fns(Ψ::ATE) + case = [] + control = [] + for treatment in Ψ.treatment + push!(case, treatment.case) + push!(control, treatment.control) + end + return Dict(Tuple(case) => 1., Tuple(control) => -1.) +end + +##################################################################### +### Interaction Average Treatment Effect ### +##################################################################### + +""" +# Interaction Average Treatment Effect / IATE + +## Definition + +For two treatments with case/control settings (1, 0): + +``IATE = E[Y|do(T₁=1, T₂=1)] - E[Y|do(T₁=1, T₂=0)] - E[Y|do(T₁=0, T₂=1)] + E[Y|do(T₁=0, T₂=0)]`` + +## Constructors + +- IATE(scm::SCM; outcome::Symbol, treatment::NamedTuple) +- IATE( + outcome::Symbol, + treatment::NamedTuple, + confounders::Union{Symbol, AbstractVector{Symbol}}, + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LinearBinaryClassifier() +) + +## Example + +```julia +Ψ = IATE(scm, outcome=:Y, treatment=(T₁=(case=1,control=0), T₂=(case=1,control=0)) +``` +""" +struct InteractionAverageTreatmentEffect <: Estimand + scm::StructuralCausalModel + outcome::Symbol + treatment::NamedTuple + function InteractionAverageTreatmentEffect(scm, outcome, treatment) + check_parameter_against_scm(scm, outcome, treatment) + return new(scm, outcome, treatment) + end +end + +const IATE = InteractionAverageTreatmentEffect + +ncases(value, Ψ::IATE) = sum(value[i] == Ψ.treatment[i].case for i in eachindex(value)) + +function indicator_fns(Ψ::IATE) + N = length(treatments(Ψ)) + key_vals = Pair[] + for cf in Iterators.product((values(Ψ.treatment[T]) for T in treatments(Ψ))...) + push!(key_vals, cf => float((-1)^(N - ncases(cf, Ψ)))) + end + return Dict(key_vals...) +end + +##################################################################### +### Methods ### +##################################################################### + +AVAILABLE_ESTIMANDS = (CM, ATE, IATE) +CMCompositeTypenames = [:CM, :ATE, :IATE] +CMCompositeEstimand = Union{(eval(x) for x in CMCompositeTypenames)...} + +# Define constructors/name for CMCompositeEstimand types +for typename in CMCompositeTypenames + ex = quote + name(::Type{$(typename)}) = string($(typename)) + + $(typename)(scm::SCM; outcome::Symbol, treatment::NamedTuple) = $(typename)(scm, outcome, treatment) + + function $(typename)(; + outcome::Symbol, + treatment::NamedTuple, + confounders::Union{Symbol, AbstractVector{Symbol}}, + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome_model = with_encoder(LinearRegressor()), + treatment_model = LinearBinaryClassifier()) + scm = StaticConfoundedModel([outcome], collect(keys(treatment)), confounders; + covariates=covariates, + outcome_model=outcome_model, + treatment_model=treatment_model + ) + return $(typename)(scm; outcome=outcome, treatment=treatment) + end + end + + eval(ex) +end + +VariableNotAChildInSCMError(variable) = ArgumentError(string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")) +TreatmentMustBeInOutcomeParentsError(variable) = ArgumentError(string("Treatment variable ", variable, " must be a parent of the outcome.")) + +function check_parameter_against_scm(scm::SCM, outcome, treatment) + eqs = equations(scm) + haskey(eqs, outcome) || throw(VariableNotAChildInSCMError(outcome)) + for treatment_variable in keys(treatment) + haskey(eqs, treatment_variable) || throw(VariableNotAChildInSCMError(treatment_variable)) + is_upstream(treatment_variable, outcome, scm) || throw(TreatmentMustBeInOutcomeParentsError(treatment_variable)) + end +end + +function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand + param_string = string( + name(T), + "\n-----", + "\nOutcome: ", Ψ.outcome, + "\nTreatment: ", Ψ.treatment + ) + println(io, param_string) +end + +outcome_equation(Ψ::CMCompositeEstimand) = Ψ.scm[outcome(Ψ)] + +get_outcome_model(Ψ::CMCompositeEstimand) = outcome_equation(Ψ).mach + +get_outcome_datas(Ψ::CMCompositeEstimand) = get_outcome_model(Ψ).data + +outcome(Ψ::CMCompositeEstimand) = Ψ.outcome +outcome(dataset, Ψ::CMCompositeEstimand) = Tables.getcolumn(dataset, outcome(Ψ)) + +function estimand_key(Ψ::CMCompositeEstimand) + return ( + join(treatments(Ψ), "_"), + string(outcome(Ψ)), + ) +end \ No newline at end of file diff --git a/src/estimation.jl b/src/counterfactual_mean_based/estimation.jl similarity index 100% rename from src/estimation.jl rename to src/counterfactual_mean_based/estimation.jl diff --git a/src/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl similarity index 100% rename from src/fluctuation.jl rename to src/counterfactual_mean_based/fluctuation.jl diff --git a/src/gradient.jl b/src/counterfactual_mean_based/gradient.jl similarity index 100% rename from src/gradient.jl rename to src/counterfactual_mean_based/gradient.jl diff --git a/src/counterfactual_mean_based/offset_and_covariate.jl b/src/counterfactual_mean_based/offset_and_covariate.jl new file mode 100644 index 00000000..5d344ccd --- /dev/null +++ b/src/counterfactual_mean_based/offset_and_covariate.jl @@ -0,0 +1,77 @@ + +data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = + min(5 / (√(n)*log(n/5)), max_lb) + +""" + data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand) + +This startegy is from [this paper](https://academic.oup.com/aje/article/191/9/1640/6580570?login=false) +but the study does not show strictly better behaviour of the strategy so not a default for now. +""" +function data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand; max_lb=0.1) + n = nrows(get_outcome_datas(Ψ)[2]) + return data_adaptive_ps_lower_bound(n; max_lb=max_lb) +end + +ps_lower_bound(Ψ::CMCompositeEstimand, lower_bound::Nothing; max_lb=0.1) = data_adaptive_ps_lower_bound(Ψ; max_lb=max_lb) +ps_lower_bound(Ψ::CMCompositeEstimand, lower_bound; max_lb=0.1) = min(max_lb, lower_bound) + + +function truncate!(v::AbstractVector, ps_lowerbound::AbstractFloat) + for i in eachindex(v) + v[i] = max(v[i], ps_lowerbound) + end +end + +function compute_offset(ŷ::UnivariateFiniteVector{Multiclass{2}}) + μy = expected_value(ŷ) + logit!(μy) + return μy +end +compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = expected_value(ŷ) +compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) + +compute_offset(Ψ::CMCompositeEstimand) = + compute_offset(MLJBase.predict(get_outcome_model(Ψ))) + +function balancing_weights(Ψ::CMCompositeEstimand, W, T; ps_lowerbound=1e-8) + density = ones(nrows(T)) + for colname ∈ Tables.columnnames(T) + mach = Ψ.scm[colname].mach + ŷ = MLJBase.predict(mach, W[colname]) + density .*= pdf.(ŷ, Tables.getcolumn(T, colname)) + end + truncate!(density, ps_lowerbound) + return 1. ./ density +end + +""" + clever_covariate_and_weights(jointT, W, G, indicator_fns; ps_lowerbound=1e-8, weighted_fluctuation=false) + +Computes the clever covariate and weights that are used to fluctuate the initial Q. + +if `weighted_fluctuation = false`: + +- ``clever_covariate(t, w) = \\frac{SpecialIndicator(t)}{p(t|w)}`` +- ``weight(t, w) = 1`` + +if `weighted_fluctuation = true`: + +- ``clever_covariate(t, w) = SpecialIndicator(t)`` +- ``weight(t, w) = \\frac{1}{p(t|w)}`` + +where SpecialIndicator(t) is defined in `indicator_fns`. +""" +function clever_covariate_and_weights(Ψ::CMCompositeEstimand, X; ps_lowerbound=1e-8, weighted_fluctuation=false) + # Compute the indicator values + T = treatments(X, Ψ) + W = confounders(X, Ψ) + indic_vals = indicator_values(indicator_fns(Ψ), T) + weights = balancing_weights(Ψ, W, T, ps_lowerbound=ps_lowerbound) + if weighted_fluctuation + return indic_vals, weights + end + # Vanilla unweighted fluctuation + indic_vals .*= weights + return indic_vals, ones(size(weights, 1)) +end diff --git a/src/estimands.jl b/src/estimands.jl index ecc882ca..91d1bcf2 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -6,206 +6,65 @@ A Estimand is a functional on distribution space Ψ: ℳ → ℜ. """ abstract type Estimand end -##################################################################### -### Conditional Mean ### -##################################################################### - -""" -# Conditional Mean / CM +treatments(Ψ::Estimand) = collect(keys(Ψ.treatment)) +treatments(dataset, Ψ::Estimand) = selectcols(dataset, treatments(Ψ)) -## Definition +confounders(dataset, Ψ::Estimand) = (;(T => selectcols(dataset, keys(Ψ.scm[T].mach.data[1])) for T in treatments(Ψ))...) -``CM(Y, T=t) = E[Y|do(T=t)]`` +AbsentLevelError(treatment_name, key, val, levels) = ArgumentError(string( + "The treatment variable ", treatment_name, "'s, '", key, "' level: '", val, + "' in Ψ does not match any level in the dataset: ", levels)) -## Constructors +AbsentLevelError(treatment_name, val, levels) = ArgumentError(string( + "The treatment variable ", treatment_name, "'s, level: '", val, + "' in Ψ does not match any level in the dataset: ", levels)) -- CM(scm::SCM; outcome::Symbol, treatment::NamedTuple) -- CM( - outcome::Symbol, - treatment::NamedTuple, - confounders::Union{Symbol, AbstractVector{Symbol}}, - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LinearBinaryClassifier() -) -## Example +""" + check_treatment_settings(settings::NamedTuple, levels, treatment_name) -```julia -Ψ = CM(scm, outcome=:Y, treatment=(T=1,)) +Checks the case/control values defining the treatment contrast are present in the dataset levels. -``` +Note: This method is for estimands like the ATE or IATE that have case/control treatment settings represented as +`NamedTuple`. """ -struct ConditionalMean <: Estimand - scm::StructuralCausalModel - outcome::Symbol - treatment::NamedTuple - function ConditionalMean(scm, outcome, treatment) - check_parameter_against_scm(scm, outcome, treatment) - return new(scm, outcome, treatment) +function check_treatment_settings(settings::NamedTuple, levels, treatment_name) + for (key, val) in zip(keys(settings), settings) + any(val .== levels) || + throw(AbsentLevelError(treatment_name, key, val, levels)) end end -const CM = ConditionalMean - -##################################################################### -### Average Treatment Effect ### -##################################################################### """ -# Average Treatment Effect / ATE - -## Definition - -``ATE(Y, T, case, control) = E[Y|do(T=case)] - E[Y|do(T=control)`` - -## Constructors - -- ATE(scm::SCM; outcome::Symbol, treatment::NamedTuple) -- ATE( - outcome::Symbol, - treatment::NamedTuple, - confounders::Union{Symbol, AbstractVector{Symbol}}, - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LinearBinaryClassifier() -) + check_treatment_settings(setting, levels, treatment_name) -## Example +Checks the value defining the treatment setting is present in the dataset levels. -```julia -Ψ = ATE(scm, outcome=:Y, treatment=(T=(case=1,control=0),) -``` +Note: This is for estimands like the CM that do not have case/control treatment settings +and are represented as simple values. """ -struct AverageTreatmentEffect <: Estimand - scm::StructuralCausalModel - outcome::Symbol - treatment::NamedTuple - function AverageTreatmentEffect(scm, outcome, treatment) - check_parameter_against_scm(scm, outcome, treatment) - return new(scm, outcome, treatment) - end +function check_treatment_settings(setting, levels, treatment_name) + any(setting .== levels) || + throw( + AbsentLevelError(treatment_name, setting, levels)) end -const ATE = AverageTreatmentEffect - -##################################################################### -### Interaction Average Treatment Effect ### -##################################################################### - """ -# Interaction Average Treatment Effect / IATE - -## Definition - -For two treatments with case/control settings (1, 0): - -``IATE = E[Y|do(T₁=1, T₂=1)] - E[Y|do(T₁=1, T₂=0)] - E[Y|do(T₁=0, T₂=1)] + E[Y|do(T₁=0, T₂=0)]`` - -## Constructors - -- IATE(scm::SCM; outcome::Symbol, treatment::NamedTuple) -- IATE( - outcome::Symbol, - treatment::NamedTuple, - confounders::Union{Symbol, AbstractVector{Symbol}}, - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LinearBinaryClassifier() -) - -## Example + check_treatment_levels(Ψ::Estimand, dataset) -```julia -Ψ = IATE(scm, outcome=:Y, treatment=(T₁=(case=1,control=0), T₂=(case=1,control=0)) -``` +Makes sure the defined treatment levels are present in the dataset. """ -struct InteractionAverageTreatmentEffect <: Estimand - scm::StructuralCausalModel - outcome::Symbol - treatment::NamedTuple - function InteractionAverageTreatmentEffect(scm, outcome, treatment) - check_parameter_against_scm(scm, outcome, treatment) - return new(scm, outcome, treatment) +function check_treatment_levels(Ψ::Estimand, dataset) + for treatment_name in treatments(Ψ) + treatment_levels = levels(Tables.getcolumn(dataset, treatment_name)) + treatment_settings = getproperty(Ψ.treatment, treatment_name) + check_treatment_settings(treatment_settings, treatment_levels, treatment_name) end end -const IATE = InteractionAverageTreatmentEffect - -##################################################################### -### Methods ### -##################################################################### - -AVAILABLE_ESTIMANDS = (CM, ATE, IATE) -CMCompositeTypenames = [:CM, :ATE, :IATE] -CMCompositeEstimand = Union{(eval(x) for x in CMCompositeTypenames)...} - -# Define constructors/name for CMCompositeEstimand types -for typename in CMCompositeTypenames - ex = quote - name(::Type{$(typename)}) = string($(typename)) - - $(typename)(scm::SCM; outcome::Symbol, treatment::NamedTuple) = $(typename)(scm, outcome, treatment) - - function $(typename)(; - outcome::Symbol, - treatment::NamedTuple, - confounders::Union{Symbol, AbstractVector{Symbol}}, - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LinearBinaryClassifier()) - scm = StaticConfoundedModel([outcome], collect(keys(treatment)), confounders; - covariates=covariates, - outcome_model=outcome_model, - treatment_model=treatment_model - ) - return $(typename)(scm; outcome=outcome, treatment=treatment) - end - end - - eval(ex) -end - -VariableNotAChildInSCMError(variable) = ArgumentError(string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")) -TreatmentMustBeInOutcomeParentsError(variable) = ArgumentError(string("Treatment variable ", variable, " must be a parent of the outcome.")) - -function check_parameter_against_scm(scm::SCM, outcome, treatment) - eqs = equations(scm) - haskey(eqs, outcome) || throw(VariableNotAChildInSCMError(outcome)) - for treatment_variable in keys(treatment) - haskey(eqs, treatment_variable) || throw(VariableNotAChildInSCMError(treatment_variable)) - is_upstream(treatment_variable, outcome, scm) || throw(TreatmentMustBeInOutcomeParentsError(treatment_variable)) - end -end - -function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand - param_string = string( - name(T), - "\n-----", - "\nOutcome: ", Ψ.outcome, - "\nTreatment: ", Ψ.treatment - ) - println(io, param_string) -end - -outcome_equation(Ψ::CMCompositeEstimand) = Ψ.scm[outcome(Ψ)] - -get_outcome_model(Ψ::CMCompositeEstimand) = outcome_equation(Ψ).mach - -get_outcome_datas(Ψ::CMCompositeEstimand) = get_outcome_model(Ψ).data - -treatments(Ψ::CMCompositeEstimand) = collect(keys(Ψ.treatment)) -treatments(dataset, Ψ::CMCompositeEstimand) = selectcols(dataset, treatments(Ψ)) - -outcome(Ψ::CMCompositeEstimand) = Ψ.outcome -outcome(dataset, Ψ::CMCompositeEstimand) = Tables.getcolumn(dataset, outcome(Ψ)) - -confounders(dataset, Ψ) = (;(T => selectcols(dataset, keys(Ψ.scm[T].mach.data[1])) for T in treatments(Ψ))...) - -function param_key(Ψ::CMCompositeEstimand) - return ( - join(treatments(Ψ), "_"), - string(outcome(Ψ)), - ) -end +""" +Function used to sort estimands for optimal estimation ordering. +""" +function estimand_key end """ optimize_ordering!(estimands::Vector{<:Estimand}) @@ -213,11 +72,11 @@ end Optimizes the order of the `estimands` to maximize reuse of fitted equations in the associated SCM. """ -optimize_ordering!(estimands::Vector{<:Estimand}) = sort!(estimands, by=param_key) +optimize_ordering!(estimands::Vector{<:Estimand}) = sort!(estimands, by=estimand_key) """ optimize_ordering(estimands::Vector{<:Estimand}) See [`optimize_ordering!`](@ref) """ -optimize_ordering(estimands::Vector{<:Estimand}) = sort(estimands, by=param_key) \ No newline at end of file +optimize_ordering(estimands::Vector{<:Estimand}) = sort(estimands, by=estimand_key) \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index fe649c15..c3e1470d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,36 +4,12 @@ selectcols(data, cols) = data |> TableOperations.select(cols...) |> Tables.columntable -data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = - min(5 / (√(n)*log(n/5)), max_lb) - -""" - data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand) - -This startegy is from [this paper](https://academic.oup.com/aje/article/191/9/1640/6580570?login=false) -but the study does not show strictly better behaviour of the strategy so not a default for now. -""" -function data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand;max_lb=0.1) - n = nrows(get_outcome_datas(Ψ)[2]) - return data_adaptive_ps_lower_bound(n; max_lb=max_lb) -end - -ps_lower_bound(Ψ::CMCompositeEstimand, lower_bound::Nothing; max_lb=0.1) = data_adaptive_ps_lower_bound(Ψ; max_lb=max_lb) -ps_lower_bound(Ψ::CMCompositeEstimand, lower_bound; max_lb=0.1) = min(max_lb, lower_bound) - - function logit!(v) for i in eachindex(v) v[i] = logit(v[i]) end end -function truncate!(v::AbstractVector, ps_lowerbound) - for i in eachindex(v) - v[i] = max(v[i], ps_lowerbound) - end -end - function nomissing(table) sch = Tables.schema(table) for type in sch.types @@ -52,29 +28,6 @@ function nomissing(table, columns) return nomissing(columns) end -ncases(value, Ψ::Estimand) = sum(value[i] == Ψ.treatment[i].case for i in eachindex(value)) - -function indicator_fns(Ψ::IATE) - N = length(treatments(Ψ)) - key_vals = Pair[] - for cf in Iterators.product((values(Ψ.treatment[T]) for T in treatments(Ψ))...) - push!(key_vals, cf => float((-1)^(N - ncases(cf, Ψ)))) - end - return Dict(key_vals...) -end - -indicator_fns(Ψ::CM) = Dict(values(Ψ.treatment) => 1.) - -function indicator_fns(Ψ::ATE) - case = [] - control = [] - for treatment in Ψ.treatment - push!(case, treatment.case) - push!(control, treatment.control) - end - return Dict(Tuple(case) => 1., Tuple(control) => -1.) -end - function indicator_values(indicators, T) indic = zeros(Float64, nrows(T)) for (index, row) in enumerate(Tables.namedtupleiterator(T)) @@ -83,58 +36,6 @@ function indicator_values(indicators, T) return indic end -NotIdentifiedError(reasons) = ArgumentError(string( - "The estimand is not identified for the following reasons: \n\t- ", join(reasons, "\n\t- "))) - -AbsentLevelError(treatment_name, key, val, levels) = ArgumentError(string( - "The treatment variable ", treatment_name, "'s, '", key, "' level: '", val, - "' in Ψ does not match any level in the dataset: ", levels)) - -AbsentLevelError(treatment_name, val, levels) = ArgumentError(string( - "The treatment variable ", treatment_name, "'s, level: '", val, - "' in Ψ does not match any level in the dataset: ", levels)) - -""" - check_treatment_settings(settings::NamedTuple, levels, treatment_name) - -Checks the case/control values defining the treatment contrast are present in the dataset levels. - -Note: This method is for estimands like the ATE or IATE that have case/control treatment settings represented as -`NamedTuple`. -""" -function check_treatment_settings(settings::NamedTuple, levels, treatment_name) - for (key, val) in zip(keys(settings), settings) - any(val .== levels) || - throw(AbsentLevelError(treatment_name, key, val, levels)) - end -end - -""" - check_treatment_settings(setting, levels, treatment_name) - -Checks the value defining the treatment setting is present in the dataset levels. - -Note: This is for estimands like the CM that do not have case/control treatment settings -and are represented as simple values. -""" -function check_treatment_settings(setting, levels, treatment_name) - any(setting .== levels) || - throw( - AbsentLevelError(treatment_name, setting, levels)) -end - -""" - check_treatment_levels(Ψ::CMCompositeEstimand, dataset) - -Makes sure the defined treatment levels are present in the dataset. -""" -function check_treatment_levels(Ψ::CMCompositeEstimand, dataset) - for treatment_name in treatments(Ψ) - treatment_levels = levels(Tables.getcolumn(dataset, treatment_name)) - treatment_settings = getproperty(Ψ.treatment, treatment_name) - check_treatment_settings(treatment_settings, treatment_levels, treatment_name) - end -end expected_value(ŷ::UnivariateFiniteVector{Multiclass{2}}) = pdf.(ŷ, levels(first(ŷ))[2]) expected_value(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = mean.(ŷ) @@ -146,57 +47,4 @@ function counterfactualTreatment(vals, T) NamedTuple{Tnames}( [categorical(repeat([vals[i]], n), levels=levels(Tables.getcolumn(T, name)), ordered=isordered(Tables.getcolumn(T, name))) for (i, name) in enumerate(Tnames)]) -end - -function compute_offset(ŷ::UnivariateFiniteVector{Multiclass{2}}) - μy = expected_value(ŷ) - logit!(μy) - return μy -end -compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = expected_value(ŷ) -compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) - -compute_offset(Ψ::CMCompositeEstimand) = - compute_offset(MLJBase.predict(get_outcome_model(Ψ))) - -function balancing_weights(Ψ::CMCompositeEstimand, W, T; ps_lowerbound=1e-8) - density = ones(nrows(T)) - for colname ∈ Tables.columnnames(T) - mach = Ψ.scm[colname].mach - ŷ = MLJBase.predict(mach, W[colname]) - density .*= pdf.(ŷ, Tables.getcolumn(T, colname)) - end - truncate!(density, ps_lowerbound) - return 1. ./ density -end - -""" - clever_covariate_and_weights(jointT, W, G, indicator_fns; ps_lowerbound=1e-8, weighted_fluctuation=false) - -Computes the clever covariate and weights that are used to fluctuate the initial Q. - -if `weighted_fluctuation = false`: - -- ``clever_covariate(t, w) = \\frac{SpecialIndicator(t)}{p(t|w)}`` -- ``weight(t, w) = 1`` - -if `weighted_fluctuation = true`: - -- ``clever_covariate(t, w) = SpecialIndicator(t)`` -- ``weight(t, w) = \\frac{1}{p(t|w)}`` - -where SpecialIndicator(t) is defined in `indicator_fns`. -""" -function clever_covariate_and_weights(Ψ::CMCompositeEstimand, X; ps_lowerbound=1e-8, weighted_fluctuation=false) - # Compute the indicator values - T = treatments(X, Ψ) - W = confounders(X, Ψ) - indic_vals = indicator_values(indicator_fns(Ψ), T) - weights = balancing_weights(Ψ, W, T, ps_lowerbound=ps_lowerbound) - if weighted_fluctuation - return indic_vals, weights - end - # Vanilla unweighted fluctuation - indic_vals .*= weights - return indic_vals, ones(size(weights, 1)) -end +end \ No newline at end of file diff --git a/test/3points_interactions.jl b/test/counterfactual_mean_based/3points_interactions.jl similarity index 96% rename from test/3points_interactions.jl rename to test/counterfactual_mean_based/3points_interactions.jl index ceb91b6e..382748c9 100644 --- a/test/3points_interactions.jl +++ b/test/counterfactual_mean_based/3points_interactions.jl @@ -10,7 +10,7 @@ using CategoricalArrays using Test using LogExpFunctions -include("helper_fns.jl") +include(joinpath(dirname(@__DIR__), "helper_fns.jl")) function dataset_scm_and_truth(;n=1000) rng = StableRNG(123) diff --git a/test/double_robustness_ate.jl b/test/counterfactual_mean_based/double_robustness_ate.jl similarity index 99% rename from test/double_robustness_ate.jl rename to test/counterfactual_mean_based/double_robustness_ate.jl index da1a8653..915bfc0e 100644 --- a/test/double_robustness_ate.jl +++ b/test/counterfactual_mean_based/double_robustness_ate.jl @@ -12,7 +12,7 @@ using StatsBase using HypothesisTests using LogExpFunctions -include("helper_fns.jl") +include(joinpath(dirname(@__DIR__), "helper_fns.jl")) """ Q and G are two logistic models diff --git a/test/double_robustness_iate.jl b/test/counterfactual_mean_based/double_robustness_iate.jl similarity index 99% rename from test/double_robustness_iate.jl rename to test/counterfactual_mean_based/double_robustness_iate.jl index 0dfa450e..c04955bb 100644 --- a/test/double_robustness_iate.jl +++ b/test/counterfactual_mean_based/double_robustness_iate.jl @@ -12,7 +12,7 @@ using MLJModels using MLJLinearModels using LogExpFunctions -include("helper_fns.jl") +include(joinpath(dirname(@__DIR__), "helper_fns.jl")) cont_interacter = InteractionTransformer(order=2) |> LinearRegressor cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1.) diff --git a/test/estimands.jl b/test/counterfactual_mean_based/estimands.jl similarity index 75% rename from test/estimands.jl rename to test/counterfactual_mean_based/estimands.jl index 6caf1f2e..6be5e712 100644 --- a/test/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -215,9 +215,9 @@ end treatment=(T₂=(case="AA", control="AC"),), ), ] - # Test param_key - @test TMLE.param_key(estimands[1]) == ("T₁_T₂", "Y₁") - @test TMLE.param_key(estimands[end]) == ("T₂", "Y₂") + # Test estimand_key + @test TMLE.estimand_key(estimands[1]) == ("T₁_T₂", "Y₁") + @test TMLE.estimand_key(estimands[end]) == ("T₂", "Y₂") # Non mutating function estimands = shuffle(rng, estimands) ordered_estimands = optimize_ordering(estimands) @@ -240,6 +240,77 @@ end @test estimands == ordered_estimands end +@testset "Test indicator_fns & indicator_values" begin + scm = StaticConfoundedModel( + [:Y], + [:T₁, :T₂, :T₃], + [:W] + ) + dataset = ( + W = [1, 2, 3, 4, 5, 6, 7, 8], + T₁ = ["A", "B", "A", "B", "A", "B", "A", "B"], + T₂ = [0, 0, 1, 1, 0, 0, 1, 1], + T₃ = ["C", "C", "C", "C", "D", "D", "D", "D"], + Y = [1, 1, 1, 1, 1, 1, 1, 1] + ) + # Conditional Mean + Ψ = CM( + scm, + outcome=:Y, + treatment=(T₁="A", T₂=1), + ) + indicator_fns = TMLE.indicator_fns(Ψ) + @test indicator_fns == Dict(("A", 1) => 1.) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) + @test indic_values == [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0] + # ATE + Ψ = ATE( + scm, + outcome=:Y, + treatment=(T₁=(case="A", control="B"), T₂=(control=0, case=1)), + ) + indicator_fns = TMLE.indicator_fns(Ψ) + @test indicator_fns == Dict( + ("A", 1) => 1.0, + ("B", 0) => -1.0 + ) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) + @test indic_values == [0.0, -1.0, 1.0, 0.0, 0.0, -1.0, 1.0, 0.0] + # 2-points IATE + Ψ = IATE( + scm, + outcome=:Y, + treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0)), + ) + indicator_fns = TMLE.indicator_fns(Ψ) + @test indicator_fns == Dict( + ("A", 1) => 1.0, + ("A", 0) => -1.0, + ("B", 1) => -1.0, + ("B", 0) => 1.0 + ) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) + @test indic_values == [-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0] + # 3-points IATE + Ψ = IATE( + scm, + outcome=:Y, + treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0), T₃=(control="D", case="C")), + ) + indicator_fns = TMLE.indicator_fns(Ψ) + @test indicator_fns == Dict( + ("A", 1, "D") => -1.0, + ("A", 1, "C") => 1.0, + ("B", 0, "D") => -1.0, + ("B", 0, "C") => 1.0, + ("B", 1, "C") => -1.0, + ("A", 0, "D") => 1.0, + ("B", 1, "D") => 1.0, + ("A", 0, "C") => -1.0 + ) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) + @test indic_values == [-1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0] +end @testset "Test structs are concrete types" begin @test isconcretetype(ATE) diff --git a/test/estimation.jl b/test/counterfactual_mean_based/estimation.jl similarity index 99% rename from test/estimation.jl rename to test/counterfactual_mean_based/estimation.jl index 365b60a9..6ce83b3f 100644 --- a/test/estimation.jl +++ b/test/counterfactual_mean_based/estimation.jl @@ -10,7 +10,7 @@ using MLJLinearModels using CategoricalArrays using LogExpFunctions -include("helper_fns.jl") +include(joinpath(dirname(@__DIR__), "helper_fns.jl")) """ Results derived by hand for this dataset: diff --git a/test/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl similarity index 100% rename from test/fluctuation.jl rename to test/counterfactual_mean_based/fluctuation.jl diff --git a/test/gradient.jl b/test/counterfactual_mean_based/gradient.jl similarity index 100% rename from test/gradient.jl rename to test/counterfactual_mean_based/gradient.jl diff --git a/test/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl similarity index 100% rename from test/non_regression_test.jl rename to test/counterfactual_mean_based/non_regression_test.jl diff --git a/test/counterfactual_mean_based/offset_and_covariate.jl b/test/counterfactual_mean_based/offset_and_covariate.jl new file mode 100644 index 00000000..e489d360 --- /dev/null +++ b/test/counterfactual_mean_based/offset_and_covariate.jl @@ -0,0 +1,141 @@ +module TestOffsetAndCovariate + +using Test +using TMLE +using CategoricalArrays +using MLJModels +using Statistics + +@testset "Test ps_lower_bound" begin + dataset = ( + T = categorical([1, 0, 1, 1, 0, 1, 1]), + Y = [1., 2., 3, 4, 5, 6, 7], + W = rand(7), + ) + Ψ = ATE( + outcome=:Y, + treatment=(T=(case=1, control=0),), + confounders=:W, + ) + max_lb = 0.1 + fit!(Ψ, dataset, verbosity=0) + # Nothing results in data adaptive lower bound no lower than max_lb + @test TMLE.ps_lower_bound(Ψ, nothing) == max_lb + @test TMLE.data_adaptive_ps_lower_bound(Ψ) == max_lb + @test TMLE.data_adaptive_ps_lower_bound(1000) == 0.02984228238321508 + # Otherwise use the provided threhsold provided it's lower than max_lb + @test TMLE.ps_lower_bound(Ψ, 1e-8) == 1.0e-8 + @test TMLE.ps_lower_bound(Ψ, 1) == 0.1 +end + +@testset "Test compute_offset, clever_covariate_and_weights: 1 treatment" begin + ### In all cases, models are "constant" models + ## First case: 1 Treamtent variable + Ψ = ATE( + outcome=:Y, + treatment=(T=(case="a", control="b"),), + confounders=[:W₁, :W₂, :W₃], + treatment_model = ConstantClassifier(), + outcome_model = ConstantRegressor() + ) + dataset = ( + T = categorical(["a", "b", "c", "a", "a", "b", "a"]), + Y = [1., 2., 3, 4, 5, 6, 7], + W₁ = rand(7), + W₂ = rand(7), + W₃ = rand(7) + ) + fit!(Ψ, dataset, verbosity=0) + + offset = TMLE.compute_offset(Ψ) + @test offset == repeat([mean(dataset.Y)], 7) + weighted_fluctuation = true + ps_lowerbound = 1e-8 + X, y = TMLE.get_outcome_datas(Ψ) + cov, w = TMLE.clever_covariate_and_weights(Ψ, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + + @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] + @test w == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] + + weighted_fluctuation = false + cov, w = TMLE.clever_covariate_and_weights(Ψ, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] + @test w == ones(7) +end + +@testset "Test compute_offset, clever_covariate_and_weights: 2 treatments" begin + Ψ = IATE( + outcome = :Y, + treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), + confounders=[:W₁, :W₂, :W₃], + treatment_model = ConstantClassifier(), + outcome_model = ConstantClassifier() + ) + dataset = ( + T₁ = categorical([1, 0, 0, 1, 1, 1, 0]), + T₂ = categorical([1, 1, 1, 1, 1, 0, 0]), + Y = categorical([1, 1, 1, 1, 0, 0, 0]), + W₁ = rand(7), + W₂ = rand(7), + W₃ = rand(7) + ) + ps_lowerbound = 1e-8 + weighted_fluctuation = false + + fit!(Ψ, dataset, verbosity=0) + + # Because the outcome is binary, the offset is the logit + offset = TMLE.compute_offset(Ψ) + @test offset == repeat([0.28768207245178085], 7) + X, y = TMLE.get_outcome_datas(Ψ) + cov, w = TMLE.clever_covariate_and_weights(Ψ, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 + @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] +end + +@testset "Test compute_offset, clever_covariate_and_weights: 3 treatments" begin + ## Third case: 3 Treatment variables + Ψ = IATE( + outcome =:Y, + treatment=(T₁=(case="a", control="b"), + T₂=(case=1, control=2), + T₃=(case=true, control=false)), + confounders=[:W], + treatment_model = ConstantClassifier(), + outcome_model = DeterministicConstantRegressor() + ) + dataset = ( + T₁ = categorical(["a", "a", "b", "b", "c", "b", "b"]), + T₂ = categorical([3, 2, 1, 1, 2, 2, 2], ordered=true), + T₃ = categorical([true, false, true, false, false, false, false], ordered=true), + Y = [1., 2., 3, 4, 5, 6, 7], + W = rand(7), + ) + ps_lowerbound = 1e-8 + weighted_fluctuation = false + + fit!(Ψ, dataset, verbosity=0) + + offset = TMLE.compute_offset(Ψ) + @test offset == repeat([4.0], 7) + X, y = TMLE.get_outcome_datas(Ψ) + cov, w = TMLE.clever_covariate_and_weights(Ψ, X; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 + @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] +end + +end + +true \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index ec693547..a10bcb65 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,17 +1,17 @@ using Test -Ψ + @time begin @test include("scm.jl") - @test include("non_regression_test.jl") @test include("utils.jl") - @test include("gradient.jl") - @test include("fluctuation.jl") - @test include("double_robustness_ate.jl") - @test include("double_robustness_iate.jl") - @test include("3points_interactions.jl") - @test include("estimation.jl") - @test include("estimands.jl") @test include("missing_management.jl") @test include("composition.jl") @test include("treatment_transformer.jl") + @test include("counterfactual_mean_based/non_regression_test.jl") + @test include("counterfactual_mean_based/gradient.jl") + @test include("counterfactual_mean_based/fluctuation.jl") + @test include("counterfactual_mean_based/double_robustness_ate.jl") + @test include("counterfactual_mean_based/double_robustness_iate.jl") + @test include("counterfactual_mean_based/3points_interactions.jl") + @test include("counterfactual_mean_based/estimation.jl") + @test include("counterfactual_mean_based/estimands.jl") end \ No newline at end of file diff --git a/test/utils.jl b/test/utils.jl index 7e14c110..f1afcf66 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -34,78 +34,6 @@ using MLJModels @test_throws TMLE.AbsentLevelError("T", 0, [1, 2, 3]) TMLE.check_treatment_levels(Ψ, dataset) end -@testset "Test indicator_fns & indicator_values" begin - scm = StaticConfoundedModel( - [:Y], - [:T₁, :T₂, :T₃], - [:W] - ) - dataset = ( - W = [1, 2, 3, 4, 5, 6, 7, 8], - T₁ = ["A", "B", "A", "B", "A", "B", "A", "B"], - T₂ = [0, 0, 1, 1, 0, 0, 1, 1], - T₃ = ["C", "C", "C", "C", "D", "D", "D", "D"], - Y = [1, 1, 1, 1, 1, 1, 1, 1] - ) - # Conditional Mean - Ψ = CM( - scm, - outcome=:Y, - treatment=(T₁="A", T₂=1), - ) - indicator_fns = TMLE.indicator_fns(Ψ) - @test indicator_fns == Dict(("A", 1) => 1.) - indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) - @test indic_values == [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0] - # ATE - Ψ = ATE( - scm, - outcome=:Y, - treatment=(T₁=(case="A", control="B"), T₂=(control=0, case=1)), - ) - indicator_fns = TMLE.indicator_fns(Ψ) - @test indicator_fns == Dict( - ("A", 1) => 1.0, - ("B", 0) => -1.0 - ) - indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) - @test indic_values == [0.0, -1.0, 1.0, 0.0, 0.0, -1.0, 1.0, 0.0] - # 2-points IATE - Ψ = IATE( - scm, - outcome=:Y, - treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0)), - ) - indicator_fns = TMLE.indicator_fns(Ψ) - @test indicator_fns == Dict( - ("A", 1) => 1.0, - ("A", 0) => -1.0, - ("B", 1) => -1.0, - ("B", 0) => 1.0 - ) - indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) - @test indic_values == [-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0] - # 3-points IATE - Ψ = IATE( - scm, - outcome=:Y, - treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0), T₃=(control="D", case="C")), - ) - indicator_fns = TMLE.indicator_fns(Ψ) - @test indicator_fns == Dict( - ("A", 1, "D") => -1.0, - ("A", 1, "C") => 1.0, - ("B", 0, "D") => -1.0, - ("B", 0, "C") => 1.0, - ("B", 1, "C") => -1.0, - ("A", 0, "D") => 1.0, - ("B", 1, "D") => 1.0, - ("A", 0, "C") => -1.0 - ) - indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) - @test indic_values == [-1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0] -end - @testset "Test expected_value" begin n = 100 X = MLJBase.table(rand(n, 3)) @@ -150,136 +78,6 @@ end @test !isordered(cfT.T₂) end -@testset "Test ps_lower_bound" begin - dataset = ( - T = categorical([1, 0, 1, 1, 0, 1, 1]), - Y = [1., 2., 3, 4, 5, 6, 7], - W = rand(7), - ) - Ψ = ATE( - outcome=:Y, - treatment=(T=(case=1, control=0),), - confounders=:W, - ) - max_lb = 0.1 - fit!(Ψ, dataset, verbosity=0) - # Nothing results in data adaptive lower bound no lower than max_lb - @test TMLE.ps_lower_bound(Ψ, nothing) == max_lb - @test TMLE.data_adaptive_ps_lower_bound(Ψ) == max_lb - @test TMLE.data_adaptive_ps_lower_bound(1000) == 0.02984228238321508 - # Otherwise use the provided threhsold provided it's lower than max_lb - @test TMLE.ps_lower_bound(Ψ, 1e-8) == 1.0e-8 - @test TMLE.ps_lower_bound(Ψ, 1) == 0.1 -end - -@testset "Test compute_offset, clever_covariate_and_weights: 1 treatment" begin - ### In all cases, models are "constant" models - ## First case: 1 Treamtent variable - Ψ = ATE( - outcome=:Y, - treatment=(T=(case="a", control="b"),), - confounders=[:W₁, :W₂, :W₃], - treatment_model = ConstantClassifier(), - outcome_model = ConstantRegressor() - ) - dataset = ( - T = categorical(["a", "b", "c", "a", "a", "b", "a"]), - Y = [1., 2., 3, 4, 5, 6, 7], - W₁ = rand(7), - W₂ = rand(7), - W₃ = rand(7) - ) - fit!(Ψ, dataset, verbosity=0) - - offset = TMLE.compute_offset(Ψ) - @test offset == repeat([mean(dataset.Y)], 7) - weighted_fluctuation = true - ps_lowerbound = 1e-8 - X, y = TMLE.get_outcome_datas(Ψ) - cov, w = TMLE.clever_covariate_and_weights(Ψ, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - - @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] - @test w == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] - - weighted_fluctuation = false - cov, w = TMLE.clever_covariate_and_weights(Ψ, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] - @test w == ones(7) -end - -@testset "Test compute_offset, clever_covariate_and_weights: 2 treatments" begin - Ψ = IATE( - outcome = :Y, - treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), - confounders=[:W₁, :W₂, :W₃], - treatment_model = ConstantClassifier(), - outcome_model = ConstantClassifier() - ) - dataset = ( - T₁ = categorical([1, 0, 0, 1, 1, 1, 0]), - T₂ = categorical([1, 1, 1, 1, 1, 0, 0]), - Y = categorical([1, 1, 1, 1, 0, 0, 0]), - W₁ = rand(7), - W₂ = rand(7), - W₃ = rand(7) - ) - ps_lowerbound = 1e-8 - weighted_fluctuation = false - - fit!(Ψ, dataset, verbosity=0) - - # Because the outcome is binary, the offset is the logit - offset = TMLE.compute_offset(Ψ) - @test offset == repeat([0.28768207245178085], 7) - X, y = TMLE.get_outcome_datas(Ψ) - cov, w = TMLE.clever_covariate_and_weights(Ψ, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 - @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] -end - -@testset "Test compute_offset, clever_covariate_and_weights: 3 treatments" begin - ## Third case: 3 Treatment variables - Ψ = IATE( - outcome =:Y, - treatment=(T₁=(case="a", control="b"), - T₂=(case=1, control=2), - T₃=(case=true, control=false)), - confounders=[:W], - treatment_model = ConstantClassifier(), - outcome_model = DeterministicConstantRegressor() - ) - dataset = ( - T₁ = categorical(["a", "a", "b", "b", "c", "b", "b"]), - T₂ = categorical([3, 2, 1, 1, 2, 2, 2], ordered=true), - T₃ = categorical([true, false, true, false, false, false, false], ordered=true), - Y = [1., 2., 3, 4, 5, 6, 7], - W = rand(7), - ) - ps_lowerbound = 1e-8 - weighted_fluctuation = false - - fit!(Ψ, dataset, verbosity=0) - - offset = TMLE.compute_offset(Ψ) - @test offset == repeat([4.0], 7) - X, y = TMLE.get_outcome_datas(Ψ) - cov, w = TMLE.clever_covariate_and_weights(Ψ, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 - @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] -end - end; true \ No newline at end of file From cef6d723df52b6023b4030b9f8f96ce7c898f7fe Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 12:46:33 +0100 Subject: [PATCH 052/151] more refactorization --- src/TMLE.jl | 9 +++------ src/counterfactual_mean_based/estimands.jl | 11 +++++++++++ .../{estimation.jl => estimators.jl} | 17 +++-------------- src/estimate.jl | 4 ---- test/composition.jl | 4 ++-- .../double_robustness_ate.jl | 6 +++--- .../double_robustness_iate.jl | 10 +++++----- .../{estimation.jl => estimators.jl} | 0 .../non_regression_test.jl | 2 +- test/runtests.jl | 3 ++- 10 files changed, 30 insertions(+), 36 deletions(-) rename src/counterfactual_mean_based/{estimation.jl => estimators.jl} (80%) rename test/counterfactual_mean_based/{estimation.jl => estimators.jl} (100%) diff --git a/src/TMLE.jl b/src/TMLE.jl index 37418d55..fe6face4 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -31,7 +31,7 @@ export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS export fit!, optimize_ordering, optimize_ordering! -export tmle!, ose!, initial +export tmle!, ose!, naive_plugin_estimate export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer, with_encoder @@ -52,10 +52,7 @@ include("counterfactual_mean_based/estimands.jl") include("counterfactual_mean_based/offset_and_covariate.jl") include("counterfactual_mean_based/fluctuation.jl") include("counterfactual_mean_based/gradient.jl") -include("counterfactual_mean_based/estimation.jl") - - - +include("counterfactual_mean_based/estimators.jl") # ############################################################################# # PRECOMPILATION WORKLOAD @@ -131,7 +128,7 @@ function run_precompile_workload() composed_result = compose((x,y) -> x - y, tmle(result₂), tmle(result₁)) # Results manipulation - initial(result₃) + naive_plugin_estimate(result₃) OneSampleTTest(tmle(result₃)) OneSampleZTest(tmle(result₃)) OneSampleTTest(ose(result₃)) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index c28855f0..b23a8bae 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -192,6 +192,17 @@ function check_parameter_against_scm(scm::SCM, outcome, treatment) end end +function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false) + models_input_variables = get_models_input_variables(adjustment_method, Ψ) + for (variable, input_variables) in zip(keys(models_input_variables), models_input_variables) + fit!(Ψ.scm[variable], dataset; + input_variables=input_variables, + verbosity=verbosity, + force=force + ) + end +end + function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand param_string = string( name(T), diff --git a/src/counterfactual_mean_based/estimation.jl b/src/counterfactual_mean_based/estimators.jl similarity index 80% rename from src/counterfactual_mean_based/estimation.jl rename to src/counterfactual_mean_based/estimators.jl index ea545a98..424208ce 100644 --- a/src/counterfactual_mean_based/estimation.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -1,14 +1,3 @@ -function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false) - models_input_variables = get_models_input_variables(adjustment_method, Ψ) - for (variable, input_variables) in zip(keys(models_input_variables), models_input_variables) - fit!(Ψ.scm[variable], dataset; - input_variables=input_variables, - verbosity=verbosity, - force=force - ) - end -end - """ tmle!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), @@ -59,10 +48,9 @@ function tmle!(Ψ::CMCompositeEstimand, dataset; ) fit!(Q, verbosity=verbosity-1) # Estimation results after TMLE - Ψ̂⁰ = mean(counterfactual_aggregate(Ψ, Q⁰)) IC, Ψ̂ = gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." - return TMLEstimate(Ψ̂⁰, Ψ̂, IC), Q + return TMLEstimate(Ψ̂, IC), Q end function ose!(Ψ::CMCompositeEstimand, dataset; @@ -86,6 +74,7 @@ function ose!(Ψ::CMCompositeEstimand, dataset; # Gradient and estimate IC, Ψ̂ = gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." - return OSEstimate(Ψ̂, Ψ̂ + mean(IC), IC), Q + return OSEstimate(Ψ̂ + mean(IC), IC), Q end +naive_plugin_estimate(Ψ::CMCompositeEstimand) = mean(counterfactual_aggregate(Ψ, get_outcome_model(Ψ))) diff --git a/src/estimate.jl b/src/estimate.jl index abfb7996..7976be0d 100644 --- a/src/estimate.jl +++ b/src/estimate.jl @@ -2,13 +2,11 @@ abstract type EICEstimate end struct TMLEstimate{T<:AbstractFloat} <: EICEstimate - Ψ̂⁰::T Ψ̂::T IC::Vector{T} end struct OSEstimate{T<:AbstractFloat} <: EICEstimate - Ψ̂⁰::T Ψ̂::T IC::Vector{T} end @@ -36,8 +34,6 @@ function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"]) end -initial(est::EICEstimate) = est.Ψ̂⁰ - """ Distributions.estimate(r::EICEstimate) diff --git a/test/composition.jl b/test/composition.jl index fa7575de..a44e2755 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -25,8 +25,8 @@ end @testset "Test cov" begin n = 10 X = rand(n, 2) - ER₁ = TMLE.TMLEstimate(0., 1., X[:, 1]) - ER₂ = TMLE.TMLEstimate(0., 0., X[:, 2]) + ER₁ = TMLE.TMLEstimate(1., X[:, 1]) + ER₂ = TMLE.TMLEstimate(0., X[:, 2]) Σ = cov(ER₁, ER₂) @test size(Σ) == (2, 2) @test Σ == cov(X) diff --git a/test/counterfactual_mean_based/double_robustness_ate.jl b/test/counterfactual_mean_based/double_robustness_ate.jl index 915bfc0e..799758bb 100644 --- a/test/counterfactual_mean_based/double_robustness_ate.jl +++ b/test/counterfactual_mean_based/double_robustness_ate.jl @@ -130,7 +130,7 @@ end test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test initial(tmle_result) == 0 + @test naive_plugin_estimate(Ψ) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LinearRegressor() @@ -159,7 +159,7 @@ end test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test initial(tmle_result) == 0 + @test naive_plugin_estimate(Ψ) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) @@ -189,7 +189,7 @@ end test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test initial(tmle_result) == 0 + @test naive_plugin_estimate(Ψ) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LinearRegressor() diff --git a/test/counterfactual_mean_based/double_robustness_iate.jl b/test/counterfactual_mean_based/double_robustness_iate.jl index c04955bb..1c2b2427 100644 --- a/test/counterfactual_mean_based/double_robustness_iate.jl +++ b/test/counterfactual_mean_based/double_robustness_iate.jl @@ -193,7 +193,7 @@ end test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test initial(tmle_result) == 0 + @test naive_plugin_estimate(Ψ) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) @@ -206,7 +206,7 @@ end test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test initial(tmle_result) ≈ -0.0 atol=1e-1 + @test naive_plugin_estimate(Ψ) ≈ -0.0 atol=1e-1 end @testset "Test Double Robustness IATE on continuous_outcome_binary_treatment_pb" begin @@ -227,7 +227,7 @@ end test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test initial(tmle_result) == 0 + @test naive_plugin_estimate(Ψ) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> cont_interacter @@ -258,7 +258,7 @@ end test_fluct_decreases_risk(Ψ, fluctuation_mach) @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test initial(tmle_result) == 0 + @test naive_plugin_estimate(Ψ) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> cat_interacter @@ -270,7 +270,7 @@ end test_fluct_decreases_risk(Ψ, fluctuation_mach) @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) # The initial estimate is far away - @test initial(tmle_result) ≈ -0.02 atol=1e-2 + @test naive_plugin_estimate(Ψ) ≈ -0.02 atol=1e-2 end diff --git a/test/counterfactual_mean_based/estimation.jl b/test/counterfactual_mean_based/estimators.jl similarity index 100% rename from test/counterfactual_mean_based/estimation.jl rename to test/counterfactual_mean_based/estimators.jl diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 18ca4188..509a4780 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -28,7 +28,7 @@ using MLJGLMInterface tmle_result, fluctuation_mach = tmle!(Ψ, dataset, ps_lowerbound=0.025, verbosity=0) # TMLE @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 - @test initial(tmle_result) ≈ -0.150078 atol = 1e-6 + @test naive_plugin_estimate(Ψ) ≈ -0.150078 atol = 1e-6 l, u = confint(OneSampleTTest(tmle_result)) @test l ≈ -0.279246 atol = 1e-6 @test u ≈ -0.091821 atol = 1e-6 diff --git a/test/runtests.jl b/test/runtests.jl index a10bcb65..1767cd0c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,11 +7,12 @@ using Test @test include("composition.jl") @test include("treatment_transformer.jl") @test include("counterfactual_mean_based/non_regression_test.jl") + @test include("counterfactual_mean_based/offset_and_covariate.jl") @test include("counterfactual_mean_based/gradient.jl") @test include("counterfactual_mean_based/fluctuation.jl") @test include("counterfactual_mean_based/double_robustness_ate.jl") @test include("counterfactual_mean_based/double_robustness_iate.jl") @test include("counterfactual_mean_based/3points_interactions.jl") - @test include("counterfactual_mean_based/estimation.jl") + @test include("counterfactual_mean_based/estimators.jl") @test include("counterfactual_mean_based/estimands.jl") end \ No newline at end of file From c0ae075bb1fa7633657177d9d3d11b648231e42f Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 13:39:29 +0100 Subject: [PATCH 053/151] fix ate double robustness test --- .../double_robustness_ate.jl | 30 +++++++++++++++---- test/helper_fns.jl | 4 +-- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/test/counterfactual_mean_based/double_robustness_ate.jl b/test/counterfactual_mean_based/double_robustness_ate.jl index 799758bb..71519650 100644 --- a/test/counterfactual_mean_based/double_robustness_ate.jl +++ b/test/counterfactual_mean_based/double_robustness_ate.jl @@ -125,10 +125,12 @@ end scm.T.model = LogisticClassifier(lambda=0) tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, Q⁰ = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away @test naive_plugin_estimate(Ψ) == 0 @@ -137,7 +139,9 @@ end scm.T.model = ConstantClassifier() tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @@ -154,10 +158,12 @@ end scm.T.model = LogisticClassifier(lambda=0) tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away @test naive_plugin_estimate(Ψ) == 0 @@ -166,7 +172,9 @@ end scm.T.model = ConstantClassifier() tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) end @@ -184,10 +192,12 @@ end scm.T.model = LogisticClassifier(lambda=0) tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away @test naive_plugin_estimate(Ψ) == 0 @@ -196,7 +206,9 @@ end scm.T.model = ConstantClassifier() tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @@ -215,17 +227,21 @@ end scm.T₂.model = LogisticClassifier(lambda=0) tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₁) + test_coverage(ose_result, ATE₁₁₋₀₁) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LinearRegressor() scm.T₁.model = ConstantClassifier() scm.T₂.model = ConstantClassifier() tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₁) + test_coverage(ose_result, ATE₁₁₋₀₁) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) @@ -236,10 +252,12 @@ end treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=0.)), ) tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₀) + test_coverage(ose_result, ATE₁₁₋₀₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() @@ -247,7 +265,9 @@ end scm.T₂.model = LogisticClassifier(lambda=0) tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₀) + test_coverage(ose_result, ATE₁₁₋₀₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end diff --git a/test/helper_fns.jl b/test/helper_fns.jl index 24a6b3af..741940a2 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -40,7 +40,7 @@ function test_coverage(result::TMLE.EICEstimate, Ψ₀) lb, ub = confint(OneSampleTTest(result)) @test lb ≤ Ψ₀ ≤ ub # OneStep - lb, ub = confint(OneSampleTTest(result)) + lb, ub = confint(OneSampleZTest(result)) @test lb ≤ Ψ₀ ≤ ub end @@ -57,4 +57,4 @@ test_mean_inf_curve_almost_zero(tmle_result::TMLE.EICEstimate; atol=1e-10) = @te This cqnnot be guaranteed in general since a well specified maximum likelihood estimator also solves the score equation. """ -test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.EICEstimate) = @test abs(mean(tmle_result.IC)) < abs(mean(tmle_result.IC)) \ No newline at end of file +test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.TMLEstimate, ose_result::TMLE.OSEstimate) = @test abs(mean(tmle_result.IC)) < abs(mean(ose_result.IC)) \ No newline at end of file From 7e82c70a1d3d03d4d71465580f17c3cc88db2399 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 13:42:59 +0100 Subject: [PATCH 054/151] fix double robustness iate --- .../double_robustness_iate.jl | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/test/counterfactual_mean_based/double_robustness_iate.jl b/test/counterfactual_mean_based/double_robustness_iate.jl index 1c2b2427..d94a3a8e 100644 --- a/test/counterfactual_mean_based/double_robustness_iate.jl +++ b/test/counterfactual_mean_based/double_robustness_iate.jl @@ -188,10 +188,12 @@ end scm.T₂.model = LogisticClassifier(lambda=0) tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away @test naive_plugin_estimate(Ψ) == 0 @@ -201,10 +203,12 @@ end scm.T₂.model = ConstantClassifier() tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away @test naive_plugin_estimate(Ψ) ≈ -0.0 atol=1e-1 end @@ -222,10 +226,12 @@ end scm.T₂.model = LogisticClassifier(lambda=0) tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away @test naive_plugin_estimate(Ψ) == 0 @@ -235,7 +241,9 @@ end scm.T₂.model = ConstantClassifier() tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @@ -254,9 +262,11 @@ end scm.T₂.model = LogisticClassifier(lambda=0) tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away @test naive_plugin_estimate(Ψ) == 0 @@ -266,9 +276,11 @@ end scm.T₂.model = ConstantClassifier() tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); + ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); test_coverage(tmle_result, Ψ₀) + test_coverage(ose_result, Ψ₀) test_fluct_decreases_risk(Ψ, fluctuation_mach) - @test_skip test_fluct_mean_inf_curve_lower_than_initial(tmle_result) + test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away @test naive_plugin_estimate(Ψ) ≈ -0.02 atol=1e-2 end From 0ada1dc69c19e6990b73fa6fa4c2e3093a8a5e0a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 14:40:58 +0100 Subject: [PATCH 055/151] add caching of covariate --- src/TMLE.jl | 2 +- src/counterfactual_mean_based/fluctuation.jl | 15 +++++++++++---- src/counterfactual_mean_based/gradient.jl | 5 ++--- .../offset_and_covariate.jl | 18 ++++++++++++++++++ test/composition.jl | 2 +- test/counterfactual_mean_based/fluctuation.jl | 3 +-- 6 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index fe6face4..cb758fd1 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -49,8 +49,8 @@ include("treatment_transformer.jl") include("adjustment.jl") include("counterfactual_mean_based/estimands.jl") -include("counterfactual_mean_based/offset_and_covariate.jl") include("counterfactual_mean_based/fluctuation.jl") +include("counterfactual_mean_based/offset_and_covariate.jl") include("counterfactual_mean_based/gradient.jl") include("counterfactual_mean_based/estimators.jl") diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index b33d965f..9c411753 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -44,7 +44,7 @@ function clever_covariate_offset_and_weights(Ψ, Q, X; ps_lowerbound=1e-8, weighted_fluctuation=false ) - offset = TMLE.compute_offset(MLJBase.predict(Q, X)) + offset = compute_offset(MLJBase.predict(Q, X)) covariate, weights = TMLE.clever_covariate_and_weights( Ψ, X; ps_lowerbound=ps_lowerbound, @@ -70,10 +70,17 @@ function MLJBase.fit(model::FluctuationModel, verbosity, X, y) weights, ) MLJBase.fit!(mach, verbosity=verbosity) - return mach, nothing, nothing + + fitresult = ( + one_dimensional_path = mach, + ) + cache = ( + weighted_covariate = clever_covariate_and_offset.covariate .* weights, + ) + return fitresult, cache, nothing end -function MLJBase.predict(model::FluctuationModel, one_dimensional_path, X) +function MLJBase.predict(model::FluctuationModel, fitresult, X) Ψ = model.Ψ Q = get_outcome_model(Ψ) clever_covariate_and_offset, weights = @@ -82,5 +89,5 @@ function MLJBase.predict(model::FluctuationModel, one_dimensional_path, X) ps_lowerbound=model.ps_lowerbound, weighted_fluctuation=model.weighted ) - return MLJBase.predict(one_dimensional_path, clever_covariate_and_offset) + return MLJBase.predict(fitresult.one_dimensional_path, clever_covariate_and_offset) end \ No newline at end of file diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index d3615e66..e4dc9fe7 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -32,7 +32,6 @@ end """ ∇W(ctf_agg, Ψ̂) = ctf_agg .- Ψ̂ - """ gradient_Y_X(cache) @@ -42,9 +41,9 @@ This part of the gradient is evaluated on the original dataset. All quantities h """ function ∇YX(Ψ::CMCompositeEstimand, Q::Machine; ps_lowerbound=1e-8) X, y = get_outcome_datas(Ψ) - H, weights = clever_covariate_and_weights(Ψ, X; ps_lowerbound=ps_lowerbound) + H = weighted_covariate(Q, Ψ, X; ps_lowerbound=ps_lowerbound) y = float(y) - gradient_Y_X_fluct = H .* weights .* (y .- expected_value(MLJBase.predict(Q))) + gradient_Y_X_fluct = H .* (y .- expected_value(MLJBase.predict(Q, X))) return gradient_Y_X_fluct end diff --git a/src/counterfactual_mean_based/offset_and_covariate.jl b/src/counterfactual_mean_based/offset_and_covariate.jl index 5d344ccd..8f8b24eb 100644 --- a/src/counterfactual_mean_based/offset_and_covariate.jl +++ b/src/counterfactual_mean_based/offset_and_covariate.jl @@ -75,3 +75,21 @@ function clever_covariate_and_weights(Ψ::CMCompositeEstimand, X; ps_lowerbound= indic_vals .*= weights return indic_vals, ones(size(weights, 1)) end + +""" + weighted_covariate(Q::Machine{<:Fluctuation}, Ψ, X; ps_lowerbound=1e-8) + +If Q is a fluctuation and caches data, the covariate has already been computed and can be retrieved. +""" +weighted_covariate(Q::Machine{<:FluctuationModel, true}, Ψ, X; ps_lowerbound=1e-8) = Q.cache.weighted_covariate + + +""" + weighted_covariate(Q::Machine, Ψ, X; ps_lowerbound=1e-8) + +Computes the weighted covariate for the gradient. +""" +function weighted_covariate(Q::Machine, Ψ, X; ps_lowerbound=1e-8) + H, weights = clever_covariate_and_weights(Ψ, X; ps_lowerbound=ps_lowerbound) + return H .* weights +end diff --git a/test/composition.jl b/test/composition.jl index a44e2755..88cd0e4f 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -83,7 +83,7 @@ end ose_confint = collect(confint(OneSampleTTest(ATE_ose_result₁₀))) @test ose_confint ≈ composed_confint atol=1e-4 # Z Test - composed_confint = collect(confint(OneSampleZTest(CM_result_composed_tmle))) + composed_confint = collect(confint(OneSampleZTest(CM_result_composed_ose))) ose_confint = collect(confint(OneSampleZTest(ATE_ose_result₁₀))) @test ose_confint ≈ composed_confint atol=1e-4 # Variance diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index 13fe023a..b968bd81 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -18,8 +18,7 @@ function test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) ps_lowerbound=ps_lowerbound, weighted_fluctuation=weighted_fluctuation ) - @test Q.fitresult.data[1] == Xfluct - @test Q.fitresult.data[3] == weights + @test Q.cache.weighted_covariate == Xfluct.covariate .* weights end From 5596a08a4b1773975bfc6fd4af589aa802e7697c Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 14:47:08 +0100 Subject: [PATCH 056/151] cache training expected value --- src/counterfactual_mean_based/fluctuation.jl | 6 +++++- src/counterfactual_mean_based/gradient.jl | 2 +- src/utils.jl | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index 9c411753..f5224f35 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -40,6 +40,9 @@ The GLM models require inputs of the same type fluctuation_input(covariate::AbstractVector{T1}, offset::AbstractVector{T2}) where {T1, T2} = (covariate=covariate, offset=convert(Vector{T1}, offset)) + +training_expected_value(Q::Machine{<:FluctuationModel}) = Q.cache.training_expected_value + function clever_covariate_offset_and_weights(Ψ, Q, X; ps_lowerbound=1e-8, weighted_fluctuation=false @@ -69,13 +72,14 @@ function MLJBase.fit(model::FluctuationModel, verbosity, X, y) y, weights, ) - MLJBase.fit!(mach, verbosity=verbosity) + fit!(mach, verbosity=verbosity) fitresult = ( one_dimensional_path = mach, ) cache = ( weighted_covariate = clever_covariate_and_offset.covariate .* weights, + training_expected_value = expected_value(predict(mach)) ) return fitresult, cache, nothing end diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index e4dc9fe7..230dc59e 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -43,7 +43,7 @@ function ∇YX(Ψ::CMCompositeEstimand, Q::Machine; ps_lowerbound=1e-8) X, y = get_outcome_datas(Ψ) H = weighted_covariate(Q, Ψ, X; ps_lowerbound=ps_lowerbound) y = float(y) - gradient_Y_X_fluct = H .* (y .- expected_value(MLJBase.predict(Q, X))) + gradient_Y_X_fluct = H .* (y .- training_expected_value(Q)) return gradient_Y_X_fluct end diff --git a/src/utils.jl b/src/utils.jl index c3e1470d..1457309f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -41,6 +41,9 @@ expected_value(ŷ::UnivariateFiniteVector{Multiclass{2}}) = pdf.(ŷ, levels(firs expected_value(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = mean.(ŷ) expected_value(ŷ::AbstractVector{<:Real}) = ŷ +training_expected_value(Q::Machine) = expected_value(predict(Q)) + + function counterfactualTreatment(vals, T) Tnames = Tables.columnnames(T) n = nrows(T) From ccc7fba8ec64484d39fc10126f643969907bc659 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 14:47:58 +0100 Subject: [PATCH 057/151] syntax adjustment --- src/counterfactual_mean_based/fluctuation.jl | 2 +- src/counterfactual_mean_based/offset_and_covariate.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index f5224f35..a7e3cc6d 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -41,7 +41,7 @@ fluctuation_input(covariate::AbstractVector{T1}, offset::AbstractVector{T2}) whe (covariate=covariate, offset=convert(Vector{T1}, offset)) -training_expected_value(Q::Machine{<:FluctuationModel}) = Q.cache.training_expected_value +training_expected_value(Q::Machine{<:FluctuationModel, }) = Q.cache.training_expected_value function clever_covariate_offset_and_weights(Ψ, Q, X; ps_lowerbound=1e-8, diff --git a/src/counterfactual_mean_based/offset_and_covariate.jl b/src/counterfactual_mean_based/offset_and_covariate.jl index 8f8b24eb..4a903db0 100644 --- a/src/counterfactual_mean_based/offset_and_covariate.jl +++ b/src/counterfactual_mean_based/offset_and_covariate.jl @@ -81,7 +81,7 @@ end If Q is a fluctuation and caches data, the covariate has already been computed and can be retrieved. """ -weighted_covariate(Q::Machine{<:FluctuationModel, true}, Ψ, X; ps_lowerbound=1e-8) = Q.cache.weighted_covariate +weighted_covariate(Q::Machine{<:FluctuationModel, }, args...; kwargs...) = Q.cache.weighted_covariate """ From b982cdd73badac54fc5221ac8a856ec31c8c8eda Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 14:59:01 +0100 Subject: [PATCH 058/151] add Statistics to test deps --- test/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/Project.toml b/test/Project.toml index c9b22e77..065a040d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" From b518415503b8c653a3dd20377df7375573fd13c3 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 15:25:13 +0100 Subject: [PATCH 059/151] add precompilation again --- src/TMLE.jl | 40 ++++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index cb758fd1..cd278bb2 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -108,31 +108,47 @@ function run_precompile_workload() outcome =:Y₁, treatment=(T₁=true,), ) - result₁, fluctuation = tmle!(Ψ₁, dataset, verbosity=0) + tmle_result₁, fluctuation = tmle!(Ψ₁, dataset, verbosity=0) + OneSampleTTest(tmle_result₁) + OneSampleZTest(tmle_result₁) + ose_result₁, fluctuation = ose!(Ψ₁, dataset, verbosity=0) + OneSampleTTest(ose_result₁) + OneSampleZTest(ose_result₁) + + naive_plugin_estimate(Ψ₁) + # ATE Ψ₂ = ATE( scm, outcome=:Y₂, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)) ) - result₂, fluctuation = tmle!(Ψ₂, dataset, verbosity=0) + tmle_result₂, fluctuation = tmle!(Ψ₂, dataset, verbosity=0) + OneSampleTTest(tmle_result₂) + OneSampleZTest(tmle_result₂) + + ose_result₂, fluctuation = ose!(Ψ₂, dataset, verbosity=0) + OneSampleTTest(ose_result₂) + OneSampleZTest(ose_result₂) + naive_plugin_estimate(Ψ₂) + # IATE Ψ₃ = IATE( scm, outcome=:Y₂, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)) ) - result₃, fluctuation = tmle!(Ψ₃, dataset, verbosity=0) + tmle_result₃, fluctuation = tmle!(Ψ₃, dataset, verbosity=0) + OneSampleTTest(tmle_result₃) + OneSampleZTest(tmle_result₃) + + ose_result₃, fluctuation = ose!(Ψ₃, dataset, verbosity=0) + OneSampleTTest(ose_result₃) + OneSampleZTest(ose_result₃) + naive_plugin_estimate(Ψ₃) # Composition - composed_result = compose((x,y) -> x - y, tmle(result₂), tmle(result₁)) - - # Results manipulation - naive_plugin_estimate(result₃) - OneSampleTTest(tmle(result₃)) - OneSampleZTest(tmle(result₃)) - OneSampleTTest(ose(result₃)) - OneSampleZTest(ose(result₃)) + composed_result = compose((x,y) -> x - y, tmle_result₂, tmle_result₁) OneSampleTTest(composed_result) OneSampleZTest(composed_result) @@ -141,6 +157,6 @@ function run_precompile_workload() end -# run_precompile_workload() +run_precompile_workload() end From 14487c79e17fa0adef6680ad8156216b7cd2455b Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 15:31:43 +0100 Subject: [PATCH 060/151] add naive_plugin_estimate! --- src/TMLE.jl | 2 +- src/counterfactual_mean_based/estimators.jl | 17 +++++++++++++++++ .../double_robustness_ate.jl | 4 ++-- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index cd278bb2..e0363f28 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -31,7 +31,7 @@ export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS export fit!, optimize_ordering, optimize_ordering! -export tmle!, ose!, naive_plugin_estimate +export tmle!, ose!, naive_plugin_estimate!, naive_plugin_estimate export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer, with_encoder diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 424208ce..453971e3 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -78,3 +78,20 @@ function ose!(Ψ::CMCompositeEstimand, dataset; end naive_plugin_estimate(Ψ::CMCompositeEstimand) = mean(counterfactual_aggregate(Ψ, get_outcome_model(Ψ))) + +function naive_plugin_estimate!(Ψ::CMCompositeEstimand, dataset; + adjustment_method=BackdoorAdjustment(), + verbosity=1, + force=false) + # Check the estimand against the dataset + check_treatment_levels(Ψ, dataset) + # Initial fit of the SCM's equations + verbosity >= 1 && @info "Fitting the required equations..." + fit!(Ψ, dataset; + adjustment_method=adjustment_method, + verbosity=verbosity, + force=force + ) + return naive_plugin_estimate(Ψ) +end + diff --git a/test/counterfactual_mean_based/double_robustness_ate.jl b/test/counterfactual_mean_based/double_robustness_ate.jl index 71519650..dc0fe023 100644 --- a/test/counterfactual_mean_based/double_robustness_ate.jl +++ b/test/counterfactual_mean_based/double_robustness_ate.jl @@ -132,7 +132,7 @@ end test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate(Ψ) == 0 + @test naive_plugin_estimate!(Ψ, dataset, verbosity=0) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LinearRegressor() @@ -165,7 +165,7 @@ end test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate(Ψ) == 0 + @test naive_plugin_estimate!(Ψ, dataset, verbosity=0) == 0 # When Q is well specified but G is misspecified scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) From 7c36c4e6468660513e34a0b9db9c880c0aa7a90c Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 20 Aug 2023 15:36:34 +0100 Subject: [PATCH 061/151] rename conditional mean to counterfactual mean --- docs/src/user_guide/estimands.md | 2 +- docs/src/walk_through.md | 2 +- src/TMLE.jl | 2 +- src/counterfactual_mean_based/estimands.jl | 10 +++++----- test/composition.jl | 4 ++-- test/counterfactual_mean_based/estimands.jl | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/src/user_guide/estimands.md b/docs/src/user_guide/estimands.md index b01cae47..412c0cae 100644 --- a/docs/src/user_guide/estimands.md +++ b/docs/src/user_guide/estimands.md @@ -16,7 +16,7 @@ At the moment, most of the work in this package has been focused on estimands th In what follows, ``P`` is a probability distribution generating an outcome ``Y``, a random vector of "treatment" variables ``\textbf{T}`` and a random vector of "confounding" variables ``\textbf{W}``. -## The Interventional Conditional Mean (CM) +## The Interventional Counterfactual Mean (CM) - Causal Question: diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index e0179176..9ae4f66c 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -91,7 +91,7 @@ AVAILABLE_ESTIMANDS At the moment there are 3 main estimand types we can estimate in TMLE.jl, we provide below a few examples. -- The Interventional Conditional Mean: +- The Interventional Counterfactual Mean: ```@example walk-through cm = CM( diff --git a/src/TMLE.jl b/src/TMLE.jl index e0363f28..a7f47b54 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -26,7 +26,7 @@ import AbstractDifferentiation as AD export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel export setmodel!, equations, reset!, parents -export ConditionalMean, CM +export CounterfactualMean, CM export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index b23a8bae..dbd85903 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -1,9 +1,9 @@ ##################################################################### -### Conditional Mean ### +### Counterfactual Mean ### ##################################################################### """ -# Conditional Mean / CM +# Counterfactual Mean / CM ## Definition @@ -27,17 +27,17 @@ ``` """ -struct ConditionalMean <: Estimand +struct CounterfactualMean <: Estimand scm::StructuralCausalModel outcome::Symbol treatment::NamedTuple - function ConditionalMean(scm, outcome, treatment) + function CounterfactualMean(scm, outcome, treatment) check_parameter_against_scm(scm, outcome, treatment) return new(scm, outcome, treatment) end end -const CM = ConditionalMean +const CM = CounterfactualMean indicator_fns(Ψ::CM) = Dict(values(Ψ.treatment) => 1.) diff --git a/test/composition.jl b/test/composition.jl index 88cd0e4f..a78e2c50 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -34,7 +34,7 @@ end @testset "Test composition CM(1) - CM(0) = ATE(1,0)" begin dataset, scm = make_dataset_and_scm(;n=1000) - # Conditional Mean T = 1 + # Counterfactual Mean T = 1 CM₁ = CM( scm, outcome = :Y, @@ -42,7 +42,7 @@ end ) CM_tmle_result₁, _ = tmle!(CM₁, dataset, verbosity=0) CM_ose_result₁, _ = ose!(CM₁, dataset, verbosity=0) - # Conditional Mean T = 0 + # Counterfactual Mean T = 0 CM₀ = CM( scm, outcome = :Y, diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index 6be5e712..b610a220 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -253,7 +253,7 @@ end T₃ = ["C", "C", "C", "C", "D", "D", "D", "D"], Y = [1, 1, 1, 1, 1, 1, 1, 1] ) - # Conditional Mean + # Counterfactual Mean Ψ = CM( scm, outcome=:Y, From 1e174e4886823442c9e7f5a9b2ac3d9a0a144b74 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 30 Aug 2023 13:58:40 +0100 Subject: [PATCH 062/151] add conditional distributions --- src/TMLE.jl | 5 +- src/adjustment.jl | 34 +-- src/counterfactual_mean_based/estimands.jl | 24 +- src/counterfactual_mean_based/estimators.jl | 66 +++++ src/distribution_factors.jl | 50 ++++ src/estimands.jl | 5 + src/scm.jl | 256 ++++++++------------ test/counterfactual_mean_based/estimands.jl | 81 ++++--- test/distribution_factors.jl | 69 ++++++ test/runtests.jl | 1 + test/scm.jl | 124 +++------- 11 files changed, 413 insertions(+), 302 deletions(-) create mode 100644 src/distribution_factors.jl create mode 100644 test/distribution_factors.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index a7f47b54..6a56f4bd 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -25,7 +25,9 @@ import AbstractDifferentiation as AD export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel +export setequation!, getequation, get_conditional_distribution, set_conditional_distribution! export setmodel!, equations, reset!, parents +export ConditionalDistribution export CounterfactualMean, CM export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE @@ -42,6 +44,7 @@ export BackdoorAdjustment # ############################################################################# include("utils.jl") +include("distribution_factors.jl") include("scm.jl") include("estimands.jl") include("estimate.jl") @@ -157,6 +160,6 @@ function run_precompile_workload() end -run_precompile_workload() +# run_precompile_workload() end diff --git a/src/adjustment.jl b/src/adjustment.jl index 8aa4cea5..3853ca05 100644 --- a/src/adjustment.jl +++ b/src/adjustment.jl @@ -16,31 +16,13 @@ in order to improve inference. """ BackdoorAdjustment(;outcome_extra=[]) = BackdoorAdjustment(outcome_extra) -""" - outcome_input_variables(treatments, confounding_variables) +confounders(::BackdoorAdjustment, Ψ::Estimand) = + union((parents(Ψ.scm, treatment) for treatment in treatments(Ψ))...) -This is defined as the set of variable corresponding to: +outcome_parents(adjustment_method::BackdoorAdjustment, Ψ::Estimand) = + Set(vcat( + treatments(Ψ), + confounders(adjustment_method, Ψ)..., + adjustment_method.outcome_extra + )) -- Treatment variables -- Confounding variables -- Extra covariates -""" -outcome_input_variables(treatments, confounding_variables, extra_covariates) = unique(vcat( - treatments, - confounding_variables, - extra_covariates -)) - -function get_models_input_variables(adjustment_method::BackdoorAdjustment, Ψ::Estimand) - models_inputs = [] - for treatment in treatments(Ψ) - push!(models_inputs, parents(Ψ.scm, treatment)) - end - unique_confounders = unique(vcat(models_inputs...)) - push!( - models_inputs, - outcome_input_variables(treatments(Ψ), unique_confounders, adjustment_method.outcome_extra) - ) - variables = Tuple(vcat(treatments(Ψ), outcome(Ψ))) - return NamedTuple{variables}(models_inputs) -end diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index dbd85903..ae62dec7 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -168,7 +168,10 @@ for typename in CMCompositeTypenames covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, outcome_model = with_encoder(LinearRegressor()), treatment_model = LinearBinaryClassifier()) - scm = StaticConfoundedModel([outcome], collect(keys(treatment)), confounders; + scm = StaticConfoundedModel( + [outcome], + collect(keys(treatment)), + confounders; covariates=covariates, outcome_model=outcome_model, treatment_model=treatment_model @@ -192,11 +195,20 @@ function check_parameter_against_scm(scm::SCM, outcome, treatment) end end -function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false) - models_input_variables = get_models_input_variables(adjustment_method, Ψ) - for (variable, input_variables) in zip(keys(models_input_variables), models_input_variables) - fit!(Ψ.scm[variable], dataset; - input_variables=input_variables, +function relevant_factors(Ψ::CMCompositeEstimand; adjustment_method=BackdoorAdjustment(), verbosity=1) + outcome_factor = TMLE.get_or_set_conditional_distribution_from_natural!( + Ψ.scm, TMLE.outcome(Ψ), + TMLE.outcome_parents(adjustment_method, Ψ); + verbosity=verbosity + ) + treatment_factors = (get_conditional_distribution(Ψ.scm, treatment, parents(Ψ.scm, treatment)) for treatment in treatments(Ψ)) + return (outcome_factor, treatment_factors...) +end + +function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false, cache=true) + for factor in relevant_factors(Ψ; adjustment_method=adjustment_method, verbosity=verbosity) + fit!(factor, dataset; + cache=cache, verbosity=verbosity, force=force ) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 453971e3..a7edd749 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -95,3 +95,69 @@ function naive_plugin_estimate!(Ψ::CMCompositeEstimand, dataset; return naive_plugin_estimate(Ψ) end +""" + +# Algorithm + +1. Split the dataset in v-folds splits +2. For each split, fit the relevant factors of the distribution on the training sets +3. Perform cross-validated selection of the fluctuation model + a. For each split compute the clever covariate and offset on the validation sets + b. Fit a fluctuate on the full validation set + c. Iterate until epsilon = 0 (usually one step in our case) +4. Compute the targeted estimate + a. For each validation set compute the estimate from the fluctuated factors + b. Average those estimates to retrieve the targeted estimate +5. Compute the estimated variance + a. For each validation set compute the variance of the IC from the fluctuated factors + b. Average those estimates to retrieve the IC variance and divide per root n + +Discussion points: + +- More of a CV question, why does the cross-validated selector solves the optimization problem presented in the paper? +""" +function cvtmle!(Ψ::CMCompositeEstimand, dataset; + resampling=nothing, + adjustment_method=BackdoorAdjustment(), + verbosity=1, + force=false, + ps_lowerbound=1e-8, + weighted_fluctuation=false) + + # Check the estimand against the dataset + check_treatment_levels(Ψ, dataset) + # Initial fit of the SCM's equations + verbosity >= 1 && @info "Fitting the required equations..." + fit!(Ψ, dataset; + resampling=resampling, + adjustment_method=adjustment_method, + verbosity=verbosity, + force=force + ) + # Get propensity score truncation threshold + ps_lowerbound = ps_lower_bound(Ψ, ps_lowerbound) + # Fit Fluctuation + verbosity >= 1 && @info "Performing TMLE..." + Q⁰ = get_outcome_model(Ψ) + X, y = Q⁰.data + Q = machine( + Fluctuation(Ψ, 0.1, ps_lowerbound, weighted_fluctuation), + X, + y + ) + fit!(Q, verbosity=verbosity-1) + # Estimation results after TMLE + IC, Ψ̂ = gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) + verbosity >= 1 && @info "Done." + return TMLEstimate(Ψ̂, IC), Q +end + +""" + +# Algorithm + +1. Split the dataset in v-folds splits +2. For each split, fit the relevant factors of the distribution on the training sets +3. +""" +function easy_cvtmle!() end diff --git a/src/distribution_factors.jl b/src/distribution_factors.jl new file mode 100644 index 00000000..0b4088bc --- /dev/null +++ b/src/distribution_factors.jl @@ -0,0 +1,50 @@ +abstract type DistributionFactor end + +function key(f::DistributionFactor) end + +mutable struct ConditionalDistribution <: DistributionFactor + outcome::Symbol + parents::Set{Symbol} + model::MLJBase.Supervised + machine::MLJBase.Machine + function ConditionalDistribution(outcome, parents, model) + outcome = Symbol(outcome) + parents = Set(Symbol(x) for x in parents) + outcome ∉ parents || throw(SelfReferringEquationError(outcome)) + return new(outcome, parents, model) + end +end + +ConditionalDistribution(;outcome, parents, model) = + ConditionalDistribution(outcome, parents, model) + +key(cd::ConditionalDistribution) = (cd.outcome, cd.parents) + +cond_dist_string(outcome, parents) = string(outcome, " | ", join(parents, ", ")) + +string_repr(cd::ConditionalDistribution) = cond_dist_string(cd.outcome, cd.parents) + +fit_message(cd::ConditionalDistribution) = string("Fitting Conditional Distribution Factor: ", string_repr(cd)) + +update_message(cd::ConditionalDistribution) = string("Reusing or Updating Conditional Distribution Factor: ", string_repr(cd)) + +Base.show(io::IO, ::MIME"text/plain", cd::ConditionalDistribution) = println(io, string_repr(cd)) + +function MLJBase.fit!(cd::ConditionalDistribution, dataset; cache=true, verbosity=1, force=false) + # Never fitted or model has changed + dofit = !isdefined(cd, :machine) || cd.model != cd.machine.model + if dofit + verbosity >= 1 && @info(fit_message(cd)) + featurenames = collect(cd.parents) + data = nomissing(dataset, vcat(featurenames, cd.outcome)) + X = selectcols(data, featurenames) + y = Tables.getcolumn(data, cd.outcome) + mach = machine(cd.model, X, y, cache=cache) + MLJBase.fit!(mach, verbosity=verbosity-1) + cd.machine = mach + # Also refit if force is true + else + verbosity >= 1 && @info(update_message(cd)) + MLJBase.fit!(cd.machine, verbosity=verbosity-1, force=force) + end +end \ No newline at end of file diff --git a/src/estimands.jl b/src/estimands.jl index 91d1bcf2..b47b3205 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -66,6 +66,11 @@ Function used to sort estimands for optimal estimation ordering. """ function estimand_key end +""" +Retrieves the relevant factors of the distribution to be fitted. +""" +relevant_factors(Ψ::Estimand; adjustment_method::AdjustmentMethod) + """ optimize_ordering!(estimands::Vector{<:Estimand}) diff --git a/src/scm.jl b/src/scm.jl index f1f03495..9052b1ae 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -1,108 +1,46 @@ ##################################################################### ### Structural Equation ### ##################################################################### + +SelfReferringEquationError(outcome) = + ArgumentError(string("Variable ", outcome, " appears on both sides of the equation.")) + """ # Structural Equation / SE ## Constructors -- SE(outcome, parents; model=nothing) -- SE(;outcome, parents, model=nothing) +- SE(outcome, parents) +- SE(;outcome, parents) ## Examples eq = SE(:Y, [:T, :W]) -eq = SE(:Y, [:T, :W], model = LinearRegressor()) +eq = SE(outcome=:Y, parents=[:T, :W]) """ -mutable struct StructuralEquation +struct StructuralEquation outcome::Symbol - parents::Vector{Symbol} - model::Union{Model, Nothing} - mach::Union{Machine, Nothing} - function StructuralEquation(outcome, parents, model) + parents::Set{Symbol} + function StructuralEquation(outcome, parents) + outcome = Symbol(outcome) + parents = Set(Symbol(x) for x in parents) outcome ∉ parents || throw(SelfReferringEquationError(outcome)) - return new(outcome, parents, model, nothing) + return new(outcome, parents) end end const SE = StructuralEquation -SE(outcome, parents; model=nothing) = SE(outcome, parents, model) -SE(;outcome, parents, model=nothing) = SE(outcome, parents, model) - -SelfReferringEquationError(outcome) = - ArgumentError(string("Variable ", outcome, " appears on both sides of the equation.")) - -NoModelError(eq::SE) = ArgumentError(string("It seems the following structural equation needs to be fitted.\n", - " Please provide a suitable model for it :\n\t⋆ ", eq)) - -fit_message(eq::SE) = string("Fitting Structural Equation corresponding to variable ", outcome(eq), ".") - - -equation_inputs(eq::SE, input_variables::Nothing) = parents(eq) -equation_inputs(eq::SE, input_variables) = input_variables - -""" - MLJBase.fit!(eq::SE, dataset; input_variables=nothing, verbosity=1, cache=true, force=false) - -Fits the outcome's Structural Equation using the dataset with inputs variables given by either: - -- The variables corresponding to parents(eq) if `input_variables`= nothing -- The alternative variables provided by `input_variables`. - -Extra keyword arguments are: - -- cache: Controls whether the associated MLJ.Machine will cache data. -- force: Controls whether to force the associated MLJ.Machine to refit even if neither the model or data has changed. -- verbosity: Controls the verbosity level -""" -function MLJBase.fit!(eq::SE, dataset; input_variables=nothing, verbosity=1, cache=true, force=false) - eq.model !== nothing || throw(NoModelError(eq)) - # Fit when: never fitted OR model has changed OR inputs have changed - dofit = eq.mach === nothing || eq.model != eq.mach.model - if !dofit - if isdefined(eq.mach, :data) - dofit = Tables.columnnames(eq.mach.data[1]) !== Tuple(equation_inputs(eq, input_variables)) - else - dofit = true - end - end - - if dofit - verbosity >= 1 && @info(fit_message(eq)) - input_colnames = equation_inputs(eq, input_variables) - data = nomissing(dataset, vcat(input_colnames, outcome(eq))) - X = selectcols(data, input_colnames) - y = Tables.getcolumn(data, outcome(eq)) - mach = machine(eq.model, X, y, cache=cache) - MLJBase.fit!(mach, verbosity=verbosity-1) - eq.mach = mach - # Also refit if force is true - else - verbosity >= 1 && force === true && @info(fit_message(eq)) - MLJBase.fit!(eq.mach, verbosity=verbosity-1, force=force) - end -end +SE(;outcome, parents) = SE(outcome, parents) -reset!(eq::SE) = eq.mach = nothing +outcome(se::SE) = se.outcome +parents(se::SE) = se.parents -function string_repr(eq::SE; subscript="") - eq_string = string(eq.outcome, " = f", subscript, "(", join(eq.parents, ", "), ")") - if eq.model !== nothing - eq_string = string(eq_string, ", ", nameof(typeof(eq.model)), ", fitted=", eq.mach !== nothing) - end - return eq_string -end +string_repr(eq::SE; subscript="") = string(outcome(eq), " = f", subscript, "(", join(parents(eq), ", "), ")") Base.show(io::IO, ::MIME"text/plain", eq::SE) = println(io, string_repr(eq)) -setmodel!(eq::SE, model::Nothing) = nothing -setmodel!(eq::SE, model::Model) = eq.model = model - -outcome(se::SE) = se.outcome -parents(se::SE) = se.parents - ##################################################################### ### Structural Causal Model ### ##################################################################### @@ -127,16 +65,68 @@ scm = SCM( """ struct StructuralCausalModel equations::Dict{Symbol, StructuralEquation} + factors::Dict end const SCM = StructuralCausalModel -SCM(;equations=Dict{Symbol, SE}()) = SCM(equations) -SCM(equations::Vararg{SE}) = - SCM(Dict(outcome(eq) => eq for eq in equations)) - +SCM(;equations=Dict{Symbol, SE}(), factors=Dict()) = SCM(equations, factors) +SCM(equations::Vararg{SE}; factors=Dict()) = SCM(Dict(outcome(eq) => eq for eq in equations), factors) equations(scm::SCM) = scm.equations +factors(scm::SCM) = scm.factors + +parents(scm::StructuralCausalModel, outcome::Symbol) = + haskey(equations(scm), outcome) ? parents(equations(scm)[outcome]) : Set{Symbol}() + +setequation!(scm::SCM, eq::SE) = scm.equations[outcome(eq)] = eq + +getequation(scm::StructuralCausalModel, key::Symbol) = scm.equations[key] + +get_conditional_distribution(scm::SCM, outcome::Symbol) = get_conditional_distribution(scm, outcome, parents(scm, outcome)) +get_conditional_distribution(scm::SCM, outcome::Symbol, parents) = scm.factors[(outcome, Set(parents))] + +NoAvailableConditionalDistributionError(scm, outcome, conditioning_set) = ArgumentError(string( + "Could not find a conditional distribution for either : \n - The required: ", + cond_dist_string(outcome, conditioning_set), + "\n - The natural (that could be used as a template): ", + cond_dist_string(outcome, parents(scm, outcome)), + "\nSet at least one of them using `set_conditional_distribution!` before proceeding to estimation." + )) + +UsingNaturalDistributionLog(scm, outcome, conditioning_set) = string( + "Could not retrieve a conditional distribution for the required: ", + cond_dist_string(outcome, conditioning_set), + ".\n Will use the model class found in the natural candidate: ", + cond_dist_string(outcome, parents(scm, outcome)) +) + +function get_or_set_conditional_distribution_from_natural!(scm::SCM, outcome::Symbol, conditioning_set::Set{Symbol}; verbosity=1) + try + return get_conditional_distribution(scm, outcome, conditioning_set) + catch KeyError + try + template_factor = get_conditional_distribution(scm, outcome) + set_conditional_distribution!(scm, outcome, conditioning_set, template_factor.model) + factor = get_conditional_distribution(scm, outcome, conditioning_set) + verbosity > 0 && @info(UsingNaturalDistributionLog(scm, outcome, conditioning_set)) + return factor + catch KeyError + throw(NoAvailableConditionalDistributionError(scm, outcome, conditioning_set)) + end + end +end + +set_conditional_distribution!(scm::SCM, outcome::Symbol, model) = + set_conditional_distribution!(scm, outcome, parents(scm, outcome), model) + +function set_conditional_distribution!(scm::SCM, outcome::Symbol, parents, model) + factor = ConditionalDistribution(outcome, parents, model) + setfactor!(scm, factor) +end + +setfactor!(scm::SCM, factor::DistributionFactor) = + scm.factors[key(factor)] = factor function string_repr(scm::SCM) scm_string = """ @@ -154,40 +144,6 @@ end Base.show(io::IO, ::MIME"text/plain", scm::SCM) = println(io, string_repr(scm)) -function Base.push!(scm::SCM, eq::SE) - key = outcome(eq) - scm[key] = eq -end - -function Base.setindex!(scm::SCM, eq::SE, key::Symbol) - if haskey(scm.equations, key) - throw(AlreadyAssignedError(key)) - end - scm.equations[key] = eq -end - -Base.getindex(scm::SCM, key::Symbol) = scm.equations[key] - -function Base.getproperty(scm::StructuralCausalModel, key::Symbol) - hasfield(StructuralCausalModel, key) && return getfield(scm, key) - return scm.equations[key] -end - -parents(scm::StructuralCausalModel, key::Symbol) = - haskey(equations(scm), key) ? parents(scm[key]) : Symbol[] - -function MLJBase.fit!(scm::SCM, dataset; verbosity=1, cache=true, force=false) - for eq in values(equations(scm)) - fit!(eq, dataset; verbosity=verbosity, cache=cache, force=force) - end -end - -function reset!(scm::SCM) - for eq in values(equations(scm)) - reset!(eq) - end -end - """ is_upstream(var₁, var₂, scm::SCM) @@ -199,27 +155,18 @@ function is_upstream(var₁, var₂, scm::SCM) if var₁ ∈ upstream return true else - upstream = vcat((parents(scm, v) for v in upstream)...) + upstream = union((parents(scm, v) for v in upstream)...) end isempty(upstream) && return false end end -function upstream_variables(var, scm::SCM) - upstream = parents(scm, var) - current_upstream = upstream - while true - current_upstream = vcat((parents(scm, v) for v in current_upstream)...) - union!(upstream, current_upstream) - isempty(current_upstream) && return upstream - end -end ##################################################################### ### StaticConfoundedModel ### ##################################################################### -vcat_covariates(treatment, confounders, covariates::Nothing) = vcat(treatment, confounders) -vcat_covariates(treatment, confounders, covariates) = vcat(treatment, confounders, covariates) +combine_outcome_parents(treatment, confounders, covariates::Nothing) = vcat(treatment, confounders) +combine_outcome_parents(treatment, confounders, covariates) = vcat(treatment, confounders, covariates) """ StaticConfoundedModel( @@ -236,22 +183,20 @@ The `outcome_model` and `treatment_model` define the relationship between the outcome (resp. treatment) and their ancestors. """ function StaticConfoundedModel( - outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}}; - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcome, treatment, confounders; + covariates = nothing, outcome_model = TreatmentTransformer() |> LinearRegressor(), - treatment_model = LinearBinaryClassifier() + treatment_model = LinearBinaryClassifier(), ) - Yeq = SE( - outcome, - vcat_covariates(treatment, confounders, covariates), - outcome_model + Yeq = SE(outcome, combine_outcome_parents(treatment, confounders, covariates)) + Teq = SE(treatment, vcat(confounders)) + outcome_factor = ConditionalDistribution(Yeq.outcome, Yeq.parents, outcome_model) + treatment_factor = ConditionalDistribution(Teq.outcome, Teq.parents, treatment_model) + factors = Dict( + key(outcome_factor) => outcome_factor, + key(treatment_factor) => treatment_factor ) - Teq = SE( - treatment, - vcat(confounders), - treatment_model - ) - return StructuralCausalModel(Yeq, Teq) + return StructuralCausalModel(Yeq, Teq; factors) end """ @@ -274,14 +219,27 @@ The `outcome_model` and `treatment_model` define the relationships between the outcomes (resp. treatments) and their ancestors. """ function StaticConfoundedModel( - outcomes::Vector{Symbol}, - treatments::Vector{Symbol}, - confounders::Union{Symbol, AbstractVector{Symbol}}; - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, + outcomes::AbstractVector, + treatments::AbstractVector, + confounders; + covariates = nothing, outcome_model = TreatmentTransformer() |> LinearRegressor(), treatment_model = LinearBinaryClassifier() ) - Yequations = (SE(outcome, vcat_covariates(treatments, confounders, covariates), outcome_model) for outcome in outcomes) - Tequations = (SE(treatment, vcat(confounders), treatment_model) for treatment in treatments) - return SCM(Yequations..., Tequations...) + Yequations = (SE(outcome, combine_outcome_parents(treatments, confounders, covariates)) for outcome in outcomes) + Tequations = (SE(treatment, vcat(confounders)) for treatment in treatments) + factors = Dict( + (outcome(Yeq) => Dict(parents(Yeq) => outcome_model) for Yeq in Yequations)..., + (outcome(Teq) => Dict(parents(Teq) => treatment_model) for Teq in Tequations)... + ) + factors = Dict() + for eq in Yequations + factor = ConditionalDistribution(eq.outcome, eq.parents, outcome_model) + factors[key(factor)] = factor + end + for eq in Tequations + factor = ConditionalDistribution(eq.outcome, eq.parents, treatment_model) + factors[key(factor)] = factor + end + return SCM(Yequations..., Tequations...; factors=factors) end \ No newline at end of file diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index b610a220..0e6b103c 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -74,13 +74,13 @@ end @testset "Test constructors" begin Ψ = CM(outcome=:Y, treatment=(T=1,), confounders=[:W]) - @test parents(Ψ.scm.Y) == [:T, :W] + @test parents(Ψ.scm, :Y) == Set([:T, :W]) Ψ = ATE(outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), confounders=:W) - @test parents(Ψ.scm.Y) == [:T₁, :T₂, :W] + @test parents(Ψ.scm, :Y) == Set([:T₁, :T₂, :W]) Ψ = IATE(outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), confounders=:W) - @test parents(Ψ.scm.Y) == [:T₁, :T₂, :W] + @test parents(Ψ.scm, :Y) == Set([:T₁, :T₂, :W]) end @@ -100,25 +100,36 @@ end ) scm = SCM( - SE(:Ycat, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], TreatmentTransformer() |> LinearBinaryClassifier()), - SE(:Ycont, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], TreatmentTransformer() |> LinearRegressor()), - SE(:T₁, [:W₁₁, :W₁₂], LinearBinaryClassifier()), - SE(:T₂, [:W₂₁, :W₂₂], LinearBinaryClassifier()), + SE(:Ycat, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C]), + SE(:Ycont, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C]), + SE(:T₁, [:W₁₁, :W₁₂]), + SE(:T₂, [:W₂₁, :W₂₂]), ) - - # Fits CM required equations + set_conditional_distribution!(scm, :Ycat, with_encoder(LinearBinaryClassifier())) + set_conditional_distribution!(scm, :Ycont, with_encoder(LinearRegressor())) + set_conditional_distribution!(scm, :T₁, LinearBinaryClassifier()) + set_conditional_distribution!(scm, :T₂, LinearBinaryClassifier()) + # Fits CM required factors Ψ = CM( scm, treatment=(T₁=1,), outcome=:Ycat ) - fit!(Ψ, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=0) - @test scm.Ycat.mach isa Machine - @test keys(scm.Ycat.mach.data[1]) == (:T₁, :W₁₁, :W₁₂, :C) - @test scm.T₁.mach isa Machine - @test keys(scm.T₁.mach.data[1]) == (:W₁₁, :W₁₂) - @test scm.Ycont.mach isa Nothing - @test scm.T₂.mach isa Nothing + log_sequence = ( + (:info, TMLE.UsingNaturalDistributionLog(scm, :Ycat, Set([:W₁₂, :W₁₁, :T₁, :C]))), + (:info, "Fitting Conditional Distribution Factor: Ycat | W₁₂, W₁₁, T₁, C"), + (:info, "Fitting Conditional Distribution Factor: T₁ | W₁₂, W₁₁"), + ) + @test_logs log_sequence... fit!(Ψ, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1) + outcome_dist = get_conditional_distribution(scm, :Ycat, Set([:W₁₂, :W₁₁, :T₁, :C])) + @test keys(fitted_params(outcome_dist.machine)) == (:linear_binary_classifier, :treatment_transformer) + @test keys(outcome_dist.machine.data[1]) == (:W₁₂, :W₁₁, :T₁, :C) + treatment_dist = get_conditional_distribution(scm, :T₁, Set([:W₁₁, :W₁₂])) + @test fitted_params(treatment_dist.machine).features == [:W₁₂, :W₁₁] + @test keys(treatment_dist.machine.data[1]) == (:W₁₂, :W₁₁) + # The natural outcome distribution has not been fitted for instance + natural_outcome_dist = get_conditional_distribution(scm, :Ycat, Set([:T₂, :W₁₂, :W₂₂, :W₂₁, :W₁₁, :T₁, :C])) + @test !isdefined(natural_outcome_dist, :machine) # Fits ATE required equations that haven't been fitted yet. Ψ = ATE( @@ -126,16 +137,21 @@ end treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), outcome =:Ycat, ) + conditioning_set = Set([:W₁₂, :W₂₂, :W₂₁, :W₁₁, :T₁, :T₂]) log_sequence = ( - (:info, "Fitting Structural Equation corresponding to variable T₂."), - (:info, "Fitting Structural Equation corresponding to variable Ycat.") + (:info, TMLE.UsingNaturalDistributionLog(scm, :Ycat, conditioning_set)), + (:info, "Fitting Conditional Distribution Factor: Ycat | W₁₂, W₂₂, W₂₁, W₁₁, T₁, T₂"), + (:info, "Reusing or Updating Conditional Distribution Factor: T₁ | W₁₂, W₁₁"), + (:info, "Fitting Conditional Distribution Factor: T₂ | W₂₂, W₂₁"), ) @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) - @test scm.Ycat.mach isa Machine - keys(scm.Ycat.mach.data[1]) - @test scm.T₁.mach isa Machine - @test scm.T₂.mach isa Machine - @test scm.Ycont.mach isa Nothing + + outcome_dist = get_conditional_distribution(scm, :Ycat, conditioning_set) + @test keys(fitted_params(outcome_dist.machine)) == (:linear_binary_classifier, :treatment_transformer) + @test Set(keys(outcome_dist.machine.data[1])) == conditioning_set + treatment_dist = get_conditional_distribution(scm, :T₂, Set([:W₂₂, :W₂₁])) + @test fitted_params(treatment_dist.machine).features == [:W₂₂, :W₂₁] + @test keys(treatment_dist.machine.data[1]) == (:W₂₂, :W₂₁) # Fits IATE required equations that haven't been fitted yet. Ψ = IATE( @@ -143,21 +159,18 @@ end treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), outcome =:Ycont, ) + conditioning_set = Set([:W₁₂, :W₂₂, :W₂₁, :W₁₁, :T₁, :T₂]) log_sequence = ( - (:info, "Fitting Structural Equation corresponding to variable Ycont."), + (:info, TMLE.UsingNaturalDistributionLog(scm, :Ycont, conditioning_set)), + (:info, "Fitting Conditional Distribution Factor: Ycont | W₁₂, W₂₂, W₂₁, W₁₁, T₁, T₂"), + (:info, "Reusing or Updating Conditional Distribution Factor: T₁ | W₁₂, W₁₁"), + (:info, "Reusing or Updating Conditional Distribution Factor: T₂ | W₂₂, W₂₁"), ) @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) - @test scm.Ycat.mach isa Machine - @test scm.T₁.mach isa Machine - @test scm.T₂.mach isa Machine - @test scm.Ycont.mach isa Machine + outcome_dist = get_conditional_distribution(scm, :Ycont, conditioning_set) + @test keys(fitted_params(outcome_dist.machine)) == (:linear_regressor, :treatment_transformer) + @test Set(keys(outcome_dist.machine.data[1])) == conditioning_set - # Change a model - scm.Ycont.model = TreatmentTransformer() |> LinearRegressor(fit_intercept=false) - log_sequence = ( - (:info, "Fitting Structural Equation corresponding to variable Ycont."), - ) - @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) end @testset "Test optimize_ordering" begin diff --git a/test/distribution_factors.jl b/test/distribution_factors.jl new file mode 100644 index 00000000..0ed19925 --- /dev/null +++ b/test/distribution_factors.jl @@ -0,0 +1,69 @@ +module TestDistributionFactors + +using Test +using TMLE +using MLJGLMInterface +using MLJBase + +@testset "Test ConditionalDistribution(outcome, parents, model, resampling)" begin + # Default constructor + cond_dist = ConditionalDistribution("Y", [:W, "T", :C], LinearBinaryClassifier()) + @test cond_dist.outcome == :Y + @test cond_dist.parents == Set([:W, :T, :C]) + @test cond_dist.model == LinearBinaryClassifier() + @test TMLE.key(cond_dist) == (:Y, Set([:W, :T, :C])) + # String representation + @test TMLE.string_repr(cond_dist) == "Y | T, W, C" +end + +@testset "Test fit!" begin + X, y = make_regression() + dataset = (Y = y, X₁ = X.x1, X₂ = X.x2) + cond_dist = ConditionalDistribution(:Y, [:X₁, :X₂], LinearRegressor()) + # Initial fit + log_sequence = ( + (:info, TMLE.fit_message(cond_dist)), + (:info, "Training machine(LinearRegressor(fit_intercept = true, …), …).") + ) + @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2) + fp = fitted_params(cond_dist.machine) + @test fp.features == [:X₂, :X₁] + @test fp.intercept != 0. + # Change the model + cond_dist.model = LinearRegressor(fit_intercept=false) + log_sequence = ( + (:info, TMLE.fit_message(cond_dist)), + (:info, "Training machine(LinearRegressor(fit_intercept = false, …), …).") + ) + @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2) + fp = fitted_params(cond_dist.machine) + @test fp.features == [:X₂, :X₁] + @test fp.intercept == 0. + # Change model's hyperparameter + cond_dist.model.fit_intercept = true + log_sequence = ( + (:info, TMLE.update_message(cond_dist)), + (:info, "Updating machine(LinearRegressor(fit_intercept = true, …), …).") + ) + @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2) + fp = fitted_params(cond_dist.machine) + @test fp.features == [:X₂, :X₁] + @test fp.intercept != 0. + # Forcing refit + log_sequence = ( + (:info, TMLE.update_message(cond_dist)), + (:info, "Training machine(LinearRegressor(fit_intercept = true, …), …).") + ) + @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2, force = true) + # No refit + log_sequence = ( + (:info, TMLE.update_message(cond_dist)), + (:info, "Not retraining machine(LinearRegressor(fit_intercept = true, …), …). Use `force=true` to force.") + ) + @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2) + +end + +end + +true \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 1767cd0c..afe47204 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using Test @time begin + @test include("distribution_factors.jl") @test include("scm.jl") @test include("utils.jl") @test include("missing_management.jl") diff --git a/test/scm.jl b/test/scm.jl index 475910dc..1f4a4ba5 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -12,28 +12,32 @@ using MLJBase # Incorrect Structural Equations @test_throws TMLE.SelfReferringEquationError(:Y) SE(:Y, [:Y]) # Correct Structural Equations - eq = SE(:Y, [:T, :W], LinearBinaryClassifier()) - @test eq isa SE - @test TMLE.string_repr(eq) == "Y = f(T, W), LinearBinaryClassifier, fitted=false" - - eq = SE(:Y, [:T, :W]) + # - outcome is a Symbol + # - parents is a Set of Symbols + eq = SE("Y", [:T, "W"]) + @test TMLE.outcome(eq) == :Y + @test TMLE.parents(eq) == Set([:T, :W]) @test TMLE.string_repr(eq) == "Y = f(T, W)" end @testset "Test SCM" begin scm = SCM() - # setindex! + # setequation! Yeq = SE(:Y, [:T, :W, :C]) - scm[:Y] = Yeq - # push! + setequation!(scm, Yeq) Teq = SE(:T, [:W]) - push!(scm, Teq) - # getindex - @test scm[:Y] === Yeq - # getproperty - @test scm.T === Teq - # Already defined equation - @test_throws TMLE.AlreadyAssignedError(:Y) push!(scm, SE(:Y, [:T])) + setequation!(scm, Teq) + # getequation + @test getequation(scm, :Y) == Yeq + @test getequation(scm, :T) == Teq + # set and get conditional_distribution with implied natural parents + cond_dist = set_conditional_distribution!(scm, :T, LinearBinaryClassifier()) + @test get_conditional_distribution(scm, :T) == cond_dist + # set and get conditional_distribution with provided parents + cond_dist = set_conditional_distribution!(scm, :Y, LinearBinaryClassifier()) + @test_throws KeyError get_conditional_distribution(scm, :Y, Set([:W, :T])) + cond_dist = TMLE.get_or_set_conditional_distribution_from_natural!(scm, :Y, Set([:W, :T])) + @test cond_dist.model == get_conditional_distribution(scm, :Y).model # string representation @test TMLE.string_repr(scm) == "Structural Causal Model:\n-----------------------\nT = f₁(W)\nY = f₂(T, W, C)\n" end @@ -41,23 +45,22 @@ end @testset "Test StaticConfoundedModel" begin # With covariates scm = StaticConfoundedModel(:Y, :T, [:W₁, :W₂], covariates=[:C₁]) - @test TMLE.parents(scm, :Y) == [:T, :W₁, :W₂, :C₁] - @test TMLE.parents(scm, :T) == [:W₁, :W₂] - @test scm.Y.model isa ProbabilisticPipeline - @test scm.T.model isa LinearBinaryClassifier - @test scm.T.model.fit_intercept === true + @test TMLE.parents(scm, :Y) == Set([:T, :W₁, :W₂, :C₁]) + @test TMLE.parents(scm, :T) == Set([:W₁, :W₂]) + @test get_conditional_distribution(scm, :T).model isa LinearBinaryClassifier + @test get_conditional_distribution(scm, :Y).model isa ProbabilisticPipeline # Without covariates scm = StaticConfoundedModel(:Y, :T, :W₁, outcome_model=LinearBinaryClassifier(), treatment_model=LinearBinaryClassifier(fit_intercept=false)) - @test TMLE.parents(scm, :Y) == [:T, :W₁] - @test TMLE.parents(scm, :T) == [:W₁] - @test scm.Y.model isa LinearBinaryClassifier - @test scm.T.model.fit_intercept === false + @test TMLE.parents(scm, :Y) == Set([:T, :W₁]) + @test TMLE.parents(scm, :T) == Set([:W₁]) + @test get_conditional_distribution(scm, :Y).model isa LinearBinaryClassifier + @test get_conditional_distribution(scm, :T).model.fit_intercept === false # With 1 covariate scm = StaticConfoundedModel(:Y, :T, :W₁, covariates=:C₁) - @test TMLE.parents(scm, :Y) == [:T, :W₁, :C₁] - @test TMLE.parents(scm, :T) == [:W₁] + @test TMLE.parents(scm, :Y) == Set([:T, :W₁, :C₁]) + @test TMLE.parents(scm, :T) == Set([:W₁]) # With multiple outcomes and treatments scm = StaticConfoundedModel( @@ -67,79 +70,28 @@ end covariates=[:C] ) - Yparents = [:T₁, :T₂, :W₁, :W₂, :W₃, :C] - @test parents(scm.Y₁) == Yparents - @test parents(scm.Y₂) == Yparents - @test parents(scm.Y₃) == Yparents - - Tparents = [:W₁, :W₂, :W₃] - @test parents(scm.T₁) == Tparents - @test parents(scm.T₂) == Tparents - -end - -@testset "Test fit!" begin - rng = StableRNG(123) - n = 100 - dataset = ( - Ycat = categorical(rand(rng, [0, 1], n)), - Ycont = rand(rng, n), - T₁ = categorical(rand(rng, [0, 1], n)), - T₂ = categorical(rand(rng, [0, 1], n)), - W₁₁ = rand(rng, n), - W₁₂ = rand(rng, n), - W₂₁ = rand(rng, n), - W₂₂ = rand(rng, n), - C = rand(rng, n) - ) - - scm = SCM( - SE(:Ycat, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], TreatmentTransformer() |> LinearBinaryClassifier()), - SE(:Ycont, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], TreatmentTransformer() |> LinearRegressor()), - SE(:T₁, [:W₁₁, :W₁₂], LinearBinaryClassifier()), - SE(:T₂, [:W₂₁, :W₂₂], LinearBinaryClassifier()), - ) - - # Fits all equations in SCM - fit_log_sequence = ( - (:info, "Fitting Structural Equation corresponding to variable Ycont."), - (:info, "Fitting Structural Equation corresponding to variable Ycat."), - (:info, "Fitting Structural Equation corresponding to variable T₁."), - (:info, "Fitting Structural Equation corresponding to variable T₂."), - ) - @test_logs fit_log_sequence... fit!(scm, dataset, verbosity = 1) - for (key, eq) in equations(scm) - @test eq.mach isa Machine - @test isdefined(eq.mach, :data) + Yparents = Set([:T₁, :T₂, :W₁, :W₂, :W₃, :C]) + for Y in [:Y₁, :Y₂, :Y₃] + @test parents(scm, Y) == Yparents + @test get_conditional_distribution(scm, Y, Yparents).model isa ProbabilisticPipeline end - # Refit will not do anything - nofit_log_sequence = () - @test_logs nofit_log_sequence... fit!(scm, dataset, verbosity = 1) - # Force refit - @test_logs fit_log_sequence... fit!(scm, dataset, verbosity = 1, force=true) - # Reset scm - reset!(scm) - for (key, eq) in equations(scm) - @test eq.mach isa Nothing + Tparents = Set([:W₁, :W₂, :W₃]) + for T in [:T₁, :T₂] + @test parents(scm, T) == Tparents + @test get_conditional_distribution(scm, T, Tparents).model isa LinearBinaryClassifier end end -@testset "Test is_upstream && upstream_variables" begin +@testset "Test is_upstream" begin scm = SCM( SE(:T₁, [:W₁]), SE(:T₂, [:W₂]), SE(:Y, [:T₁, :T₂]), ) - parents(scm, :Y) @test TMLE.is_upstream(:T₁, :Y, scm) === true @test TMLE.is_upstream(:W₁, :Y, scm) === true @test TMLE.is_upstream(:Y, :T₁, scm) === false - - @test TMLE.upstream_variables(:Y, scm) == [:T₁, :T₂, :W₁, :W₂] - @test TMLE.upstream_variables(:T₁, scm) == [:W₁] - @test TMLE.upstream_variables(:W₁, scm) == Symbol[] - end end From c9e62c261e08da7f5c2d8ac27f145f6f24005089 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 1 Sep 2023 16:46:38 +0100 Subject: [PATCH 063/151] add more content --- src/counterfactual_mean_based/estimands.jl | 14 ++-- src/counterfactual_mean_based/estimators.jl | 71 +++++------------ src/counterfactual_mean_based/fluctuation.jl | 78 ++++++++++++------- src/counterfactual_mean_based/gradient.jl | 29 +++---- .../offset_and_covariate.jl | 43 +++++----- src/distribution_factors.jl | 69 ++++++++++++++-- src/estimands.jl | 5 +- test/counterfactual_mean_based/fluctuation.jl | 61 ++++++++++++--- test/counterfactual_mean_based/gradient.jl | 17 ++-- .../non_regression_test.jl | 4 +- 10 files changed, 234 insertions(+), 157 deletions(-) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index ae62dec7..bea1ec9d 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -195,24 +195,26 @@ function check_parameter_against_scm(scm::SCM, outcome, treatment) end end -function relevant_factors(Ψ::CMCompositeEstimand; adjustment_method=BackdoorAdjustment(), verbosity=1) - outcome_factor = TMLE.get_or_set_conditional_distribution_from_natural!( - Ψ.scm, TMLE.outcome(Ψ), - TMLE.outcome_parents(adjustment_method, Ψ); +function get_relevant_factors(Ψ::CMCompositeEstimand; adjustment_method=BackdoorAdjustment(), verbosity=1) + outcome_factor = get_or_set_conditional_distribution_from_natural!( + Ψ.scm, outcome(Ψ), + outcome_parents(adjustment_method, Ψ); verbosity=verbosity ) treatment_factors = (get_conditional_distribution(Ψ.scm, treatment, parents(Ψ.scm, treatment)) for treatment in treatments(Ψ)) - return (outcome_factor, treatment_factors...) + return NamedTuple{(outcome_factor.outcome, (tf.outcome for tf in treatment_factors)...)}([outcome_factor, treatment_factors...]) end function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false, cache=true) - for factor in relevant_factors(Ψ; adjustment_method=adjustment_method, verbosity=verbosity) + relevant_factors = get_relevant_factors(Ψ; adjustment_method=adjustment_method, verbosity=verbosity) + for factor in relevant_factors fit!(factor, dataset; cache=cache, verbosity=verbosity, force=force ) end + return relevant_factors end function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index a7edd749..2803cc74 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -19,7 +19,8 @@ Performs Targeted Minimum Loss Based Estimation of the target estimand. - ps_lowerbound: The propensity score will be truncated to respect this lower bound. - weighted_fluctuation: To use a weighted fluctuation instead of the vanilla TMLE, can improve stability. """ -function tmle!(Ψ::CMCompositeEstimand, dataset; +function tmle!(Ψ::CMCompositeEstimand, dataset; + resampling=nothing, adjustment_method=BackdoorAdjustment(), verbosity=1, force=false, @@ -30,27 +31,26 @@ function tmle!(Ψ::CMCompositeEstimand, dataset; check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's equations verbosity >= 1 && @info "Fitting the required equations..." - fit!(Ψ, dataset; + initial_factors = fit!(Ψ, dataset; adjustment_method=adjustment_method, verbosity=verbosity, force=force ) # Get propensity score truncation threshold - ps_lowerbound = ps_lower_bound(Ψ, ps_lowerbound) + n = nrows(initial_factors[outcome(Ψ)].machine.data[2]) + ps_lowerbound = ps_lower_bound(n, ps_lowerbound) # Fit Fluctuation verbosity >= 1 && @info "Performing TMLE..." - Q⁰ = get_outcome_model(Ψ) - X, y = Q⁰.data - Q = machine( - Fluctuation(Ψ, 0.1, ps_lowerbound, weighted_fluctuation), - X, - y + targeted_factors = fluctuate(initial_factors, Ψ; + tol=nothing, + verbosity=verbosity, + weighted_fluctuation=weighted_fluctuation, + ps_lowerbound=ps_lowerbound ) - fit!(Q, verbosity=verbosity-1) # Estimation results after TMLE - IC, Ψ̂ = gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) + IC, Ψ̂ = gradient_and_estimate(Ψ, targeted_factors; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." - return TMLEstimate(Ψ̂, IC), Q + return TMLEstimate(Ψ̂, IC), targeted_factors end function ose!(Ψ::CMCompositeEstimand, dataset; @@ -77,8 +77,6 @@ function ose!(Ψ::CMCompositeEstimand, dataset; return OSEstimate(Ψ̂ + mean(IC), IC), Q end -naive_plugin_estimate(Ψ::CMCompositeEstimand) = mean(counterfactual_aggregate(Ψ, get_outcome_model(Ψ))) - function naive_plugin_estimate!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, @@ -87,12 +85,16 @@ function naive_plugin_estimate!(Ψ::CMCompositeEstimand, dataset; check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's equations verbosity >= 1 && @info "Fitting the required equations..." - fit!(Ψ, dataset; - adjustment_method=adjustment_method, + Q = get_or_set_conditional_distribution_from_natural!( + Ψ.scm, outcome(Ψ), + outcome_parents(adjustment_method, Ψ); + verbosity=verbosity + ) + fit!(Q, dataset; verbosity=verbosity, force=force ) - return naive_plugin_estimate(Ψ) + return mean(counterfactual_aggregate(Ψ, Q)) end """ @@ -123,41 +125,6 @@ function cvtmle!(Ψ::CMCompositeEstimand, dataset; force=false, ps_lowerbound=1e-8, weighted_fluctuation=false) - - # Check the estimand against the dataset - check_treatment_levels(Ψ, dataset) - # Initial fit of the SCM's equations - verbosity >= 1 && @info "Fitting the required equations..." - fit!(Ψ, dataset; - resampling=resampling, - adjustment_method=adjustment_method, - verbosity=verbosity, - force=force - ) - # Get propensity score truncation threshold - ps_lowerbound = ps_lower_bound(Ψ, ps_lowerbound) - # Fit Fluctuation - verbosity >= 1 && @info "Performing TMLE..." - Q⁰ = get_outcome_model(Ψ) - X, y = Q⁰.data - Q = machine( - Fluctuation(Ψ, 0.1, ps_lowerbound, weighted_fluctuation), - X, - y - ) - fit!(Q, verbosity=verbosity-1) - # Estimation results after TMLE - IC, Ψ̂ = gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) - verbosity >= 1 && @info "Done." - return TMLEstimate(Ψ̂, IC), Q end -""" - -# Algorithm -1. Split the dataset in v-folds splits -2. For each split, fit the relevant factors of the distribution on the training sets -3. -""" -function easy_cvtmle!() end diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index a7e3cc6d..66599467 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -1,31 +1,63 @@ mutable struct ContinuousFluctuation <: MLJBase.Probabilistic Ψ::CMCompositeEstimand - tol::Float64 + initial_factors::NamedTuple + tol::Union{Nothing, Float64} ps_lowerbound::Float64 weighted::Bool end mutable struct BinaryFluctuation <: MLJBase.Probabilistic Ψ::CMCompositeEstimand - tol::Float64 + initial_factors::NamedTuple + tol::Union{Nothing, Float64} ps_lowerbound::Float64 weighted::Bool end FluctuationModel = Union{ContinuousFluctuation, BinaryFluctuation} -function Fluctuation(Ψ, tol, ps_lowerbound, weighted) - outcome_scitype = scitype(get_outcome_datas(Ψ)[2]) +function Fluctuation(Ψ, initial_factors; tol=nothing, ps_lowerbound=1e-8, weighted=false) + outcome_scitype = scitype(initial_factors[outcome(Ψ)].machine.data[2]) if outcome_scitype <: AbstractVector{<:MLJBase.Continuous} - return ContinuousFluctuation(Ψ, tol, ps_lowerbound, weighted) + return ContinuousFluctuation(Ψ, initial_factors, tol, ps_lowerbound, weighted) elseif outcome_scitype <: AbstractVector{<:Finite} - return BinaryFluctuation(Ψ, tol, ps_lowerbound, weighted) + return BinaryFluctuation(Ψ, initial_factors, tol, ps_lowerbound, weighted) else throw(ArgumentError("Cannot proceed with outcome with target_scitype: $outcome_scitype")) end end +function fluctuate(initial_factors, Ψ; + tol=nothing, + verbosity=1, + weighted_fluctuation=false, + ps_lowerbound=1e-8 + ) + Qfluct_model = Fluctuation(Ψ, initial_factors; + tol=tol, + ps_lowerbound=ps_lowerbound, + weighted=weighted_fluctuation + ) + Q⁰, G⁰ = splitQG(initial_factors, Ψ) + # Retrieve dataset from Q⁰ + X, y = Q⁰.machine.data + dataset = merge(X, NamedTuple{(Q⁰.outcome, )}([y])) + Qfluct = ConditionalDistribution( + Q⁰.outcome, + Q⁰.parents, + Qfluct_model + ) + fit!(Qfluct, dataset, verbosity=verbosity-1) + return merge(NamedTuple{(Q⁰.outcome,)}([Qfluct]), G⁰) +end + +function splitQG(factors, Ψ) + Q = factors[outcome(Ψ)] + G = (; (key => factors[key] for key in treatments(Ψ))...) + return Q, G +end + epsilon(Qstar) = fitted_params(fitted_params(Qstar).fitresult).coef[1] one_dimensional_path(model::ContinuousFluctuation) = LinearRegressor(fit_intercept=false, offsetcol = :offset) @@ -40,32 +72,24 @@ The GLM models require inputs of the same type fluctuation_input(covariate::AbstractVector{T1}, offset::AbstractVector{T2}) where {T1, T2} = (covariate=covariate, offset=convert(Vector{T1}, offset)) - training_expected_value(Q::Machine{<:FluctuationModel, }) = Q.cache.training_expected_value -function clever_covariate_offset_and_weights(Ψ, Q, X; - ps_lowerbound=1e-8, - weighted_fluctuation=false - ) - offset = compute_offset(MLJBase.predict(Q, X)) - covariate, weights = TMLE.clever_covariate_and_weights( - Ψ, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation +function clever_covariate_offset_and_weights(model::FluctuationModel, X) + Ψ = model.Ψ + Q⁰, G⁰ = splitQG(model.initial_factors, Ψ) + offset = compute_offset(MLJBase.predict(Q⁰, X)) + covariate, weights = clever_covariate_and_weights( + Ψ, G⁰, X; + ps_lowerbound=model.ps_lowerbound, + weighted_fluctuation=model.weighted ) Xfluct = fluctuation_input(covariate, offset) return Xfluct, weights end function MLJBase.fit(model::FluctuationModel, verbosity, X, y) - Ψ = model.Ψ - Q = get_outcome_model(Ψ) clever_covariate_and_offset, weights = - clever_covariate_offset_and_weights( - Ψ, Q, X; - ps_lowerbound=model.ps_lowerbound, - weighted_fluctuation=model.weighted - ) + clever_covariate_offset_and_weights(model, X) mach = machine( one_dimensional_path(model), clever_covariate_and_offset, @@ -85,13 +109,7 @@ function MLJBase.fit(model::FluctuationModel, verbosity, X, y) end function MLJBase.predict(model::FluctuationModel, fitresult, X) - Ψ = model.Ψ - Q = get_outcome_model(Ψ) clever_covariate_and_offset, weights = - clever_covariate_offset_and_weights( - Ψ, Q, X; - ps_lowerbound=model.ps_lowerbound, - weighted_fluctuation=model.weighted - ) + clever_covariate_offset_and_weights(model, X) return MLJBase.predict(fitresult.one_dimensional_path, clever_covariate_and_offset) end \ No newline at end of file diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index 230dc59e..a956c1cd 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -8,19 +8,18 @@ For the ATE with binary treatment, confounded by W it is equal to: ``ctf_agg = Q(1, W) - Q(0, W)`` """ -function counterfactual_aggregate(Ψ::CMCompositeEstimand, Q::Machine) - X = get_outcome_datas(Ψ)[1] - Ttemplate = treatments(X, Ψ) +function counterfactual_aggregate(Ψ::CMCompositeEstimand, Q::ConditionalDistribution) + X = Q.machine.data[1] + Ttemplate = selectcols(X, treatments(Ψ)) n = nrows(Ttemplate) ctf_agg = zeros(n) # Loop over Treatment settings for (vals, sign) in indicator_fns(Ψ) # Counterfactual dataset for a given treatment setting - T_ct = counterfactualTreatment(vals, Ttemplate) - X_ct = selectcols(merge(X, T_ct), keys(X)) - # Counterfactual predictions with F - ŷ = predict(Q, X_ct) - ctf_agg .+= sign .* expected_value(ŷ) + T_ct = TMLE.counterfactualTreatment(vals, Ttemplate) + X_ct = merge(X, T_ct) + # Counterfactual mean + ctf_agg .+= sign .* expected_value(Q, X_ct) end return ctf_agg end @@ -39,18 +38,20 @@ end This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. """ -function ∇YX(Ψ::CMCompositeEstimand, Q::Machine; ps_lowerbound=1e-8) - X, y = get_outcome_datas(Ψ) - H = weighted_covariate(Q, Ψ, X; ps_lowerbound=ps_lowerbound) +function ∇YX(Ψ::CMCompositeEstimand, Q, G; ps_lowerbound=1e-8) + Qmach = Q.machine + X, y = Qmach.data + H = weighted_covariate(Qmach, G, Ψ, X; ps_lowerbound=ps_lowerbound) y = float(y) - gradient_Y_X_fluct = H .* (y .- training_expected_value(Q)) + gradient_Y_X_fluct = H .* (y .- training_expected_value(Qmach)) return gradient_Y_X_fluct end -function gradient_and_estimate(Ψ::CMCompositeEstimand, Q::Machine; ps_lowerbound=1e-8) +function gradient_and_estimate(Ψ::CMCompositeEstimand, factors; ps_lowerbound=1e-8) + Q, G = TMLE.splitQG(factors, Ψ) ctf_agg = counterfactual_aggregate(Ψ, Q) Ψ̂ = mean(ctf_agg) - IC = ∇YX(Ψ, Q; ps_lowerbound = ps_lowerbound) .+ ∇W(ctf_agg, Ψ̂) + IC = ∇YX(Ψ, Q, G; ps_lowerbound = ps_lowerbound) .+ ∇W(ctf_agg, Ψ̂) return IC, Ψ̂ end \ No newline at end of file diff --git a/src/counterfactual_mean_based/offset_and_covariate.jl b/src/counterfactual_mean_based/offset_and_covariate.jl index 4a903db0..24271889 100644 --- a/src/counterfactual_mean_based/offset_and_covariate.jl +++ b/src/counterfactual_mean_based/offset_and_covariate.jl @@ -1,20 +1,18 @@ -data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = - min(5 / (√(n)*log(n/5)), max_lb) - """ data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand) This startegy is from [this paper](https://academic.oup.com/aje/article/191/9/1640/6580570?login=false) but the study does not show strictly better behaviour of the strategy so not a default for now. """ -function data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand; max_lb=0.1) - n = nrows(get_outcome_datas(Ψ)[2]) - return data_adaptive_ps_lower_bound(n; max_lb=max_lb) -end +data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = + min(5 / (√(n)*log(n/5)), max_lb) + +ps_lower_bound(n::Int, lower_bound::Nothing; max_lb=0.1) = data_adaptive_ps_lower_bound(n; max_lb=max_lb) +ps_lower_bound(n::Int, lower_bound; max_lb=0.1) = min(max_lb, lower_bound) -ps_lower_bound(Ψ::CMCompositeEstimand, lower_bound::Nothing; max_lb=0.1) = data_adaptive_ps_lower_bound(Ψ; max_lb=max_lb) -ps_lower_bound(Ψ::CMCompositeEstimand, lower_bound; max_lb=0.1) = min(max_lb, lower_bound) +treatments(dataset, G) = selectcols(dataset, keys(G)) +confounders(dataset, G) = (;(key => selectcols(dataset, keys(cd.machine.data[1])) for (key, cd) in zip(keys(G), G))...) function truncate!(v::AbstractVector, ps_lowerbound::AbstractFloat) @@ -31,18 +29,14 @@ end compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = expected_value(ŷ) compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) -compute_offset(Ψ::CMCompositeEstimand) = - compute_offset(MLJBase.predict(get_outcome_model(Ψ))) -function balancing_weights(Ψ::CMCompositeEstimand, W, T; ps_lowerbound=1e-8) - density = ones(nrows(T)) - for colname ∈ Tables.columnnames(T) - mach = Ψ.scm[colname].mach - ŷ = MLJBase.predict(mach, W[colname]) - density .*= pdf.(ŷ, Tables.getcolumn(T, colname)) +function balancing_weights(Gs, dataset; ps_lowerbound=1e-8) + jointlikelihood = ones(nrows(dataset)) + for G ∈ Gs + jointlikelihood .*= likelihood(G, dataset) end - truncate!(density, ps_lowerbound) - return 1. ./ density + truncate!(jointlikelihood, ps_lowerbound) + return 1. ./ jointlikelihood end """ @@ -62,12 +56,11 @@ if `weighted_fluctuation = true`: where SpecialIndicator(t) is defined in `indicator_fns`. """ -function clever_covariate_and_weights(Ψ::CMCompositeEstimand, X; ps_lowerbound=1e-8, weighted_fluctuation=false) +function clever_covariate_and_weights(Ψ::CMCompositeEstimand, Gs, dataset; ps_lowerbound=1e-8, weighted_fluctuation=false) # Compute the indicator values - T = treatments(X, Ψ) - W = confounders(X, Ψ) + T = treatments(dataset, Gs) indic_vals = indicator_values(indicator_fns(Ψ), T) - weights = balancing_weights(Ψ, W, T, ps_lowerbound=ps_lowerbound) + weights = balancing_weights(Gs, dataset; ps_lowerbound=ps_lowerbound) if weighted_fluctuation return indic_vals, weights end @@ -89,7 +82,7 @@ weighted_covariate(Q::Machine{<:FluctuationModel, }, args...; kwargs...) = Q.cac Computes the weighted covariate for the gradient. """ -function weighted_covariate(Q::Machine, Ψ, X; ps_lowerbound=1e-8) - H, weights = clever_covariate_and_weights(Ψ, X; ps_lowerbound=ps_lowerbound) +function weighted_covariate(Q::Machine, G, Ψ, X; ps_lowerbound=1e-8) + H, weights = clever_covariate_and_weights(Ψ, G, X; ps_lowerbound=ps_lowerbound) return H .* weights end diff --git a/src/distribution_factors.jl b/src/distribution_factors.jl index 0b4088bc..d4a8f345 100644 --- a/src/distribution_factors.jl +++ b/src/distribution_factors.jl @@ -2,12 +2,12 @@ abstract type DistributionFactor end function key(f::DistributionFactor) end -mutable struct ConditionalDistribution <: DistributionFactor +mutable struct ConditionalDistribution{T <: MLJBase.Supervised} <: DistributionFactor outcome::Symbol parents::Set{Symbol} - model::MLJBase.Supervised + model::T machine::MLJBase.Machine - function ConditionalDistribution(outcome, parents, model) + function ConditionalDistribution{T}(outcome, parents, model::T) where T <: MLJBase.Supervised outcome = Symbol(outcome) parents = Set(Symbol(x) for x in parents) outcome ∉ parents || throw(SelfReferringEquationError(outcome)) @@ -15,6 +15,9 @@ mutable struct ConditionalDistribution <: DistributionFactor end end +ConditionalDistribution(outcome, parents, model::T) where T <: MLJBase.Supervised = + ConditionalDistribution{T}(outcome, parents, model) + ConditionalDistribution(;outcome, parents, model) = ConditionalDistribution(outcome, parents, model) @@ -30,12 +33,14 @@ update_message(cd::ConditionalDistribution) = string("Reusing or Updating Condit Base.show(io::IO, ::MIME"text/plain", cd::ConditionalDistribution) = println(io, string_repr(cd)) +get_featurenames(cd::ConditionalDistribution) = sort(collect(cd.parents)) + function MLJBase.fit!(cd::ConditionalDistribution, dataset; cache=true, verbosity=1, force=false) # Never fitted or model has changed dofit = !isdefined(cd, :machine) || cd.model != cd.machine.model if dofit verbosity >= 1 && @info(fit_message(cd)) - featurenames = collect(cd.parents) + featurenames = get_featurenames(cd) data = nomissing(dataset, vcat(featurenames, cd.outcome)) X = selectcols(data, featurenames) y = Tables.getcolumn(data, cd.outcome) @@ -47,4 +52,58 @@ function MLJBase.fit!(cd::ConditionalDistribution, dataset; cache=true, verbosit verbosity >= 1 && @info(update_message(cd)) MLJBase.fit!(cd.machine, verbosity=verbosity-1, force=force) end -end \ No newline at end of file +end + +MLJBase.fitted_params(cd::ConditionalDistribution) = fitted_params(cd.machine) + +function MLJBase.predict(cd::ConditionalDistribution, dataset) + featurenames = get_featurenames(cd) + X = selectcols(dataset, featurenames) + return predict(cd.machine, X) +end + +function expected_value(factor::ConditionalDistribution, dataset) + return expected_value(predict(factor, dataset)) +end + +function likelihood(factor::ConditionalDistribution, dataset) + ŷ = predict(factor, dataset) + y = Tables.getcolumn(dataset, factor.outcome) + return pdf.(ŷ, y) +end + +## CVCounterPart + +mutable struct CVConditionalDistribution + outcome::Symbol + parents::Set{Symbol} + model::MLJBase.Supervised + resampling::ResamplingStrategy + machines::Vector{MLJBase.Machine} +end + +function MLJBase.fit!(factor::CVConditionalDistribution, dataset; cache=true, verbosity=1, force=false) + verbosity >= 1 && @info(fit_message(factor)) + featurenames = get_featurenames(cd) + data = nomissing(dataset, vcat(featurenames, cd.outcome)) + X = selectcols(data, featurenames) + y = Tables.getcolumn(data, cd.outcome) + machines = Machine[] + for (train, val) in MLJBase.train_test_pairs(resampling, nrows(data), data) + Xtrain = selectrows(X, train) + ytrain = selectrows(y, train) + mach = machine(factor.model, Xtrain, ytrain) + fit!(mach, verbosity=verbosity-1) + push!(factor.machines, mach) + end + factor.machines = machines +end + +function MLJBase.predict(factor::CVConditionalDistribution, dataset) + ŷ = Vector{Any}(undef, size(dataset, 1)) + for (fold, (_, val)) in enumerate(MLJBase.train_test_pairs(resampling, nrows(data), data)) + Xval = selectrows(X, val) + ŷ[val] = predict(factor.machines[fold], Xval) + end + return ŷ +end diff --git a/src/estimands.jl b/src/estimands.jl index b47b3205..eedd3627 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -7,9 +7,6 @@ A Estimand is a functional on distribution space Ψ: ℳ → ℜ. abstract type Estimand end treatments(Ψ::Estimand) = collect(keys(Ψ.treatment)) -treatments(dataset, Ψ::Estimand) = selectcols(dataset, treatments(Ψ)) - -confounders(dataset, Ψ::Estimand) = (;(T => selectcols(dataset, keys(Ψ.scm[T].mach.data[1])) for T in treatments(Ψ))...) AbsentLevelError(treatment_name, key, val, levels) = ArgumentError(string( "The treatment variable ", treatment_name, "'s, '", key, "' level: '", val, @@ -69,7 +66,7 @@ function estimand_key end """ Retrieves the relevant factors of the distribution to be fitted. """ -relevant_factors(Ψ::Estimand; adjustment_method::AdjustmentMethod) +get_relevant_factors(Ψ::Estimand; adjustment_method::AdjustmentMethod) """ optimize_ordering!(estimands::Vector{<:Estimand}) diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index b968bd81..613d86f2 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -39,14 +39,35 @@ end W₂ = rand(7), W₃ = rand(7) ) - fit!(Ψ, dataset, verbosity=0) - + initial_factors = fit!(Ψ, dataset, verbosity=0) ps_lowerbound = 1e-8 weighted_fluctuation = true - test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) + targeted_factors = TMLE.fluctuate(initial_factors, Ψ; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation, + verbosity=1, + tol=nothing + ) + @test targeted_factors.T === initial_factors.T + @test targeted_factors.Y !== initial_factors.Y + @test targeted_factors.Y.machine.cache.weighted_covariate == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] + Q = targeted_factors.Y + X = Q.machine.data[1] + Xfluct, weights = TMLE.clever_covariate_offset_and_weights(Q.model, X) + @test weights == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] + @test targeted_factors.Y.machine.cache.weighted_covariate == Xfluct.covariate .* weights weighted_fluctuation = false - test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) + targeted_factors = TMLE.fluctuate(initial_factors, Ψ; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation, + verbosity=1, + tol=nothing + ) + Q = targeted_factors.Y + X = Q.machine.data[1] + Xfluct, weights = TMLE.clever_covariate_offset_and_weights(Q.model, X) + @test weights == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] end @testset "Test Fluctuation with 2 Treatments" begin @@ -68,14 +89,24 @@ end ps_lowerbound = 1e-8 weighted_fluctuation = false - fit!(Ψ, dataset, verbosity=0) - test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) + initial_factors = fit!(Ψ, dataset, verbosity=0) + targeted_factors = TMLE.fluctuate(initial_factors, Ψ; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation, + verbosity=1, + tol=nothing + ) + @test targeted_factors.T₁ === initial_factors.T₁ + @test targeted_factors.T₂ === initial_factors.T₂ + @test targeted_factors.Y !== initial_factors.Y + @test targeted_factors.Y.machine.cache.weighted_covariate ≈ + [2.45, -3.27, -3.27, 2.45, 2.45, -6.13, 8.17] atol=0.01 end @testset "Test Fluctuation with 3 Treatments" begin ## Third case: 3 Treatment variables Ψ = IATE( - outcome =:Y, + outcome = :Y, treatment=(T₁=(case="a", control="b"), T₂=(case=1, control=2), T₃=(case=true, control=false)), @@ -93,8 +124,19 @@ end ps_lowerbound = 1e-8 weighted_fluctuation = false - fit!(Ψ, dataset, verbosity=0) - test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) + initial_factors = fit!(Ψ, dataset, verbosity=0) + targeted_factors = TMLE.fluctuate(initial_factors, Ψ; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation, + verbosity=1, + tol=nothing + ) + @test targeted_factors.T₁ === initial_factors.T₁ + @test targeted_factors.T₂ === initial_factors.T₂ + @test targeted_factors.T₃ === initial_factors.T₃ + @test targeted_factors.Y !== initial_factors.Y + @test targeted_factors.Y.machine.cache.weighted_covariate ≈ + [0.0, 8.58, -21.44, 8.58, 0.0, -4.29, -4.29] atol=0.01 end @testset "Test fluctuation_input" begin @@ -109,7 +151,6 @@ end X = TMLE.fluctuation_input([1.f0, 2.f0], [1., 2.]) @test X.covariate isa Vector{Float32} @test X.offset isa Vector{Float32} - end end diff --git a/test/counterfactual_mean_based/gradient.jl b/test/counterfactual_mean_based/gradient.jl index 378409a8..33ca8f0b 100644 --- a/test/counterfactual_mean_based/gradient.jl +++ b/test/counterfactual_mean_based/gradient.jl @@ -34,30 +34,29 @@ end outcome_model = with_encoder(InteractionTransformer(order=2) |> LinearRegressor()) ) dataset = one_treatment_dataset(;n=100) - fit!(Ψ, dataset, verbosity=0) + factors = fit!(Ψ, dataset, verbosity=0) # Retrieve outcome model - Q = TMLE.get_outcome_model(Ψ) + Q = factors.Y + G = (T=factors.T, ) linear_model = fitted_params(Q).deterministic_pipeline.linear_regressor intercept = linear_model.intercept coefs = Dict(linear_model.coefs) # Counterfactual aggregate - ctf_agg = TMLE.counterfactual_aggregate(Ψ, Q) + ctf_agg = TMLE.counterfactual_aggregate(Ψ, Q, G) expected_ctf_agg = (intercept .+ coefs[:T] .+ dataset.W.*coefs[:W] .+ dataset.W.*coefs[:T_W]) .- (intercept .+ dataset.W.*coefs[:W]) @test ctf_agg ≈ expected_ctf_agg atol=1e-10 - # Retrieve propensity score model - G = Ψ.scm[:T].mach # Gradient Y|X - H = 1 ./ pdf.(predict(G), dataset.T) .* [t == 1 ? 1. : -1. for t in dataset.T] - expected_∇YX = H .* (dataset.Y .- predict(Q)) - ∇YX = TMLE.∇YX(Ψ, Q; ps_lowerbound=ps_lowerbound) + H = 1 ./ pdf.(predict(G.T.machine), dataset.T) .* [t == 1 ? 1. : -1. for t in dataset.T] + expected_∇YX = H .* (dataset.Y .- predict(Q.machine)) + ∇YX = TMLE.∇YX(Ψ, Q, G; ps_lowerbound=ps_lowerbound) @test expected_∇YX == ∇YX # Gradient W expectedΨ̂ = mean(ctf_agg) ∇W = TMLE.∇W(ctf_agg, expectedΨ̂) @test ∇W == ctf_agg .- expectedΨ̂ # gradient_and_estimate - IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) + IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, factors; ps_lowerbound=ps_lowerbound) @test expectedΨ̂ == Ψ̂ @test IC == ∇YX .+ ∇W end diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 509a4780..bcc535ba 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -25,10 +25,10 @@ using MLJGLMInterface treatment_model = LinearBinaryClassifier() ) - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, ps_lowerbound=0.025, verbosity=0) + tmle_result, targeted_factors = tmle!(Ψ, dataset; ps_lowerbound=0.025, verbosity=0) # TMLE @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 - @test naive_plugin_estimate(Ψ) ≈ -0.150078 atol = 1e-6 + @test naive_plugin_estimate!(Ψ, dataset) ≈ -0.150078 atol = 1e-6 l, u = confint(OneSampleTTest(tmle_result)) @test l ≈ -0.279246 atol = 1e-6 @test u ≈ -0.091821 atol = 1e-6 From 0c979e4174f6fa9abde073ce9e7d6d32d5785cbb Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 5 Sep 2023 22:19:59 +0100 Subject: [PATCH 064/151] add work in progress --- README.md | 7 + sandbox.R | 121 +++++++ src/TMLE.jl | 3 +- src/counterfactual_mean_based/estimands.jl | 28 +- src/counterfactual_mean_based/estimators.jl | 83 +---- src/counterfactual_mean_based/fluctuation.jl | 81 ++--- src/counterfactual_mean_based/gradient.jl | 22 +- .../offset_and_covariate.jl | 2 +- src/distribution_factors.jl | 109 ------ src/estimands.jl | 74 +++- src/estimate.jl | 163 --------- src/estimates.jl | 342 ++++++++++++++++++ src/scm.jl | 72 ++-- src/utils.jl | 2 +- .../non_regression_test.jl | 32 +- 15 files changed, 674 insertions(+), 467 deletions(-) create mode 100644 sandbox.R delete mode 100644 src/distribution_factors.jl delete mode 100644 src/estimate.jl create mode 100644 src/estimates.jl diff --git a/README.md b/README.md index 1740c70a..982a0f31 100644 --- a/README.md +++ b/README.md @@ -8,3 +8,10 @@ This package enables the estimation of various Causal Inference related estimands using Targeted Minimum Loss-Based Estimation (TMLE). TMLE.jl is based on top of [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/), which means any MLJ compliant machine-learning model can be used here. **New to TMLE.jl?** [Get started now](https://targene.github.io/TMLE.jl/stable/) + +## Notes to self + +- Factors should probably not be exposed. For instance resampling should only be specified once per estimation procedure and not for each factor. +- For multiple TMLE steps, convergence when the mean of the IC is below: σ/n where σ is the standard deviation of the IC. +- Try make selectcols preserve input type +- Watch for discrepancy in train_validation_indices with missing data... \ No newline at end of file diff --git a/sandbox.R b/sandbox.R new file mode 100644 index 00000000..e6917201 --- /dev/null +++ b/sandbox.R @@ -0,0 +1,121 @@ +library(data.table) +library(tmle3) +library(sl3) +data <- fread( + paste0( + "https://raw.githubusercontent.com/tlverse/tlverse-data/master/", + "wash-benefits/washb_data.csv" + ), + stringsAsFactors = TRUE +) + +node_list <- list( + W = c( + "month", "aged", "sex", "momage", "momedu", + "momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", + "elec", "floor", "walls", "roof", "asset_wardrobe", + "asset_table", "asset_chair", "asset_khat", + "asset_chouki", "asset_tv", "asset_refrig", + "asset_bike", "asset_moto", "asset_sewmach", + "asset_mobile" + ), + A = "tr", + Y = "whz" +) + +processed <- process_missing(data, node_list) +data <- processed$data +node_list <- processed$node_list + +ate_spec <- tmle_ATE( + treatment_level = "Nutrition + WSH", + control_level = "Control" +) +tmle_spec = ate_spec + +# choose base learners +lrnr_mean <- make_learner(Lrnr_mean) +lrnr_rf <- make_learner(Lrnr_ranger) + +# define metalearners appropriate to data types +ls_metalearner <- make_learner(Lrnr_nnls) +mn_metalearner <- make_learner( + Lrnr_solnp, metalearner_linear_multinomial, + loss_loglik_multinomial +) +sl_Y <- Lrnr_sl$new( + learners = list(lrnr_mean, lrnr_rf), + metalearner = ls_metalearner +) +sl_A <- Lrnr_sl$new( + learners = list(lrnr_mean, lrnr_rf), + metalearner = mn_metalearner +) +learner_list <- list(A = sl_A, Y = sl_Y) + +tmle_task <- tmle_spec$make_tmle_task(data, node_list) + +initial_likelihood <- tmle_spec$make_initial_likelihood(tmle_task, learner_list) + +updater <- tmle_spec$make_updater() +targeted_likelihood <- tmle_spec$make_targeted_likelihood(initial_likelihood, updater) + +tmle_params <- tmle_spec$make_params(tmle_task, targeted_likelihood) +updater$tmle_params <- tmle_params + +fit <- fit_tmle3(tmle_task, targeted_likelihood, tmle_params, updater) + +update_fold <- updater$update_fold +maxit <- 100 + +# seed current estimates +current_estimates <- lapply(updater$tmle_params, function(tmle_param) { + tmle_param$estimates(tmle_task, update_fold) +}) + +likelihood = initial_likelihood +likelihood_values <- likelihood$cache$get_values(A_factor, tmle_task, fold_number) +likelihood_values <- A_factor$get_likelihood(tmle_task, fold_number) + +update_node <- updater$update_nodes[[1]] +submodel_data <- updater$generate_submodel_data( + likelihood, tmle_task, + fold_number, update_node, + drop_censored = TRUE + ) + +for (steps in seq_len(maxit)) { + updater$update_step(likelihood, tmle_task, update_fold) + + # update estimates based on updated likelihood + private$.current_estimates <- lapply(self$tmle_params, function(tmle_param) { + tmle_param$estimates(tmle_task, update_fold) + }) + + if (self$check_convergence(tmle_task, update_fold)) { + break + } + + if (self$use_best) { + self$update_best(likelihood) + } +} + +if (self$use_best) { + self$update_best(likelihood) + likelihood$cache$set_best() +} + +private$.steps <- self$updater$steps + + +estimates <- lapply( + self$tmle_params, + function(tmle_param) { + tmle_param$estimates(self$tmle_task, self$updater$update_fold) + } +) + +private$.estimates <- estimates +private$.ED <- ED_from_estimates(estimates) + diff --git a/src/TMLE.jl b/src/TMLE.jl index 6a56f4bd..a9a745ca 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -44,10 +44,9 @@ export BackdoorAdjustment # ############################################################################# include("utils.jl") -include("distribution_factors.jl") include("scm.jl") include("estimands.jl") -include("estimate.jl") +include("estimates.jl") include("treatment_transformer.jl") include("adjustment.jl") diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index bea1ec9d..b7097af5 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -1,5 +1,5 @@ ##################################################################### -### Counterfactual Mean ### +### Counterfactual Mean ### ##################################################################### """ @@ -195,26 +195,10 @@ function check_parameter_against_scm(scm::SCM, outcome, treatment) end end -function get_relevant_factors(Ψ::CMCompositeEstimand; adjustment_method=BackdoorAdjustment(), verbosity=1) - outcome_factor = get_or_set_conditional_distribution_from_natural!( - Ψ.scm, outcome(Ψ), - outcome_parents(adjustment_method, Ψ); - verbosity=verbosity - ) - treatment_factors = (get_conditional_distribution(Ψ.scm, treatment, parents(Ψ.scm, treatment)) for treatment in treatments(Ψ)) - return NamedTuple{(outcome_factor.outcome, (tf.outcome for tf in treatment_factors)...)}([outcome_factor, treatment_factors...]) -end - -function MLJBase.fit!(Ψ::CMCompositeEstimand, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, force=false, cache=true) - relevant_factors = get_relevant_factors(Ψ; adjustment_method=adjustment_method, verbosity=verbosity) - for factor in relevant_factors - fit!(factor, dataset; - cache=cache, - verbosity=verbosity, - force=force - ) - end - return relevant_factors +function get_relevant_factors(Ψ::CMCompositeEstimand; adjustment_method=BackdoorAdjustment()) + outcome_model = ExpectedValue(Ψ.scm, outcome(Ψ), outcome_parents(adjustment_method, Ψ)) + treatment_factors = Tuple(ConditionalDistribution(Ψ.scm, treatment, parents(Ψ.scm, treatment)) for treatment in treatments(Ψ)) + return CMRelevantFactors(Ψ.scm, outcome_model, treatment_factors) end function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand @@ -241,4 +225,4 @@ function estimand_key(Ψ::CMCompositeEstimand) join(treatments(Ψ), "_"), string(outcome(Ψ)), ) -end \ No newline at end of file +end diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 2803cc74..135e0504 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -1,56 +1,41 @@ -""" - tmle!(Ψ::CMCompositeEstimand, dataset; - adjustment_method=BackdoorAdjustment(), - verbosity=1, - force=false, - ps_lowerbound=1e-8, - weighted_fluctuation=false - ) +choose_initial_dataset(dataset, nomissing_dataset, resampling::Nothing) = dataset +choose_initial_dataset(dataset, nomissing_dataset, resampling) = nomissing_dataset -Performs Targeted Minimum Loss Based Estimation of the target estimand. - -## Arguments - -- Ψ: An estimand of interest. -- dataset: A table respecting the `Tables.jl` interface. -- adjustment_method: A confounding adjustment method. -- verbosity: Level of logging. -- force: To force refit of machines in the SCM . -- ps_lowerbound: The propensity score will be truncated to respect this lower bound. -- weighted_fluctuation: To use a weighted fluctuation instead of the vanilla TMLE, can improve stability. -""" -function tmle!(Ψ::CMCompositeEstimand, dataset; +function tmle!(Ψ::CMCompositeEstimand, models, dataset; resampling=nothing, adjustment_method=BackdoorAdjustment(), verbosity=1, force=false, ps_lowerbound=1e-8, - weighted_fluctuation=false + weighted_fluctuation=false, + factors_cache=nothing ) # Check the estimand against the dataset - check_treatment_levels(Ψ, dataset) + TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's equations verbosity >= 1 && @info "Fitting the required equations..." - initial_factors = fit!(Ψ, dataset; - adjustment_method=adjustment_method, - verbosity=verbosity, - force=force + relevant_factors = TMLE.get_relevant_factors(Ψ; adjustment_method=adjustment_method) + nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) + initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, resampling) + initial_factors_estimate = TMLE.estimate(relevant_factors, resampling, models, initial_factors_dataset; + factors_cache=factors_cache, + verbosity=verbosity ) # Get propensity score truncation threshold - n = nrows(initial_factors[outcome(Ψ)].machine.data[2]) - ps_lowerbound = ps_lower_bound(n, ps_lowerbound) + n = nrows(nomissing_dataset) + ps_lowerbound = TMLE.ps_lower_bound(n, ps_lowerbound) # Fit Fluctuation verbosity >= 1 && @info "Performing TMLE..." - targeted_factors = fluctuate(initial_factors, Ψ; + targeted_factors_estimate = TMLE.fluctuate(initial_factors_estimate, Ψ, nomissing_dataset; tol=nothing, verbosity=verbosity, weighted_fluctuation=weighted_fluctuation, ps_lowerbound=ps_lowerbound ) # Estimation results after TMLE - IC, Ψ̂ = gradient_and_estimate(Ψ, targeted_factors; ps_lowerbound=ps_lowerbound) + IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." - return TMLEstimate(Ψ̂, IC), targeted_factors + return TMLEstimate(Ψ̂, IC), targeted_factors_estimate end function ose!(Ψ::CMCompositeEstimand, dataset; @@ -95,36 +80,4 @@ function naive_plugin_estimate!(Ψ::CMCompositeEstimand, dataset; force=force ) return mean(counterfactual_aggregate(Ψ, Q)) -end - -""" - -# Algorithm - -1. Split the dataset in v-folds splits -2. For each split, fit the relevant factors of the distribution on the training sets -3. Perform cross-validated selection of the fluctuation model - a. For each split compute the clever covariate and offset on the validation sets - b. Fit a fluctuate on the full validation set - c. Iterate until epsilon = 0 (usually one step in our case) -4. Compute the targeted estimate - a. For each validation set compute the estimate from the fluctuated factors - b. Average those estimates to retrieve the targeted estimate -5. Compute the estimated variance - a. For each validation set compute the variance of the IC from the fluctuated factors - b. Average those estimates to retrieve the IC variance and divide per root n - -Discussion points: - -- More of a CV question, why does the cross-validated selector solves the optimization problem presented in the paper? -""" -function cvtmle!(Ψ::CMCompositeEstimand, dataset; - resampling=nothing, - adjustment_method=BackdoorAdjustment(), - verbosity=1, - force=false, - ps_lowerbound=1e-8, - weighted_fluctuation=false) -end - - +end \ No newline at end of file diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index 66599467..8384667a 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -1,67 +1,40 @@ - -mutable struct ContinuousFluctuation <: MLJBase.Probabilistic - Ψ::CMCompositeEstimand - initial_factors::NamedTuple - tol::Union{Nothing, Float64} - ps_lowerbound::Float64 - weighted::Bool -end - -mutable struct BinaryFluctuation <: MLJBase.Probabilistic +mutable struct Fluctuation <: MLJBase.Supervised Ψ::CMCompositeEstimand - initial_factors::NamedTuple + initial_factors::TMLE.MLCMRelevantFactors tol::Union{Nothing, Float64} ps_lowerbound::Float64 weighted::Bool end -FluctuationModel = Union{ContinuousFluctuation, BinaryFluctuation} - -function Fluctuation(Ψ, initial_factors; tol=nothing, ps_lowerbound=1e-8, weighted=false) - outcome_scitype = scitype(initial_factors[outcome(Ψ)].machine.data[2]) - if outcome_scitype <: AbstractVector{<:MLJBase.Continuous} - return ContinuousFluctuation(Ψ, initial_factors, tol, ps_lowerbound, weighted) - elseif outcome_scitype <: AbstractVector{<:Finite} - return BinaryFluctuation(Ψ, initial_factors, tol, ps_lowerbound, weighted) - else - throw(ArgumentError("Cannot proceed with outcome with target_scitype: $outcome_scitype")) - end -end +Fluctuation(Ψ, initial_factors; tol=nothing, ps_lowerbound=1e-8, weighted=false) = + Fluctuation(Ψ, initial_factors, tol, ps_lowerbound, weighted) -function fluctuate(initial_factors, Ψ; +function fluctuate(initial_factors_estimate::MLCMRelevantFactors, Ψ, dataset; tol=nothing, verbosity=1, weighted_fluctuation=false, - ps_lowerbound=1e-8 + ps_lowerbound=1e-8, + factors_cache=nothing ) - Qfluct_model = Fluctuation(Ψ, initial_factors; + Qfluct_model = TMLE.Fluctuation(Ψ, initial_factors_estimate; tol=tol, ps_lowerbound=ps_lowerbound, weighted=weighted_fluctuation ) - Q⁰, G⁰ = splitQG(initial_factors, Ψ) - # Retrieve dataset from Q⁰ - X, y = Q⁰.machine.data - dataset = merge(X, NamedTuple{(Q⁰.outcome, )}([y])) - Qfluct = ConditionalDistribution( - Q⁰.outcome, - Q⁰.parents, - Qfluct_model + fluctuated_outcome_mean = TMLE.estimate( + initial_factors_estimate.outcome_mean.estimand, + dataset, + Qfluct_model, + nothing; + factors_cache=factors_cache, + verbosity=verbosity ) - fit!(Qfluct, dataset, verbosity=verbosity-1) - return merge(NamedTuple{(Q⁰.outcome,)}([Qfluct]), G⁰) + fluctuated_propensity_score = initial_factors_estimate.propensity_score + return MLCMRelevantFactors(initial_factors_estimate.estimand, fluctuated_outcome_mean, fluctuated_propensity_score) end -function splitQG(factors, Ψ) - Q = factors[outcome(Ψ)] - G = (; (key => factors[key] for key in treatments(Ψ))...) - return Q, G -end - -epsilon(Qstar) = fitted_params(fitted_params(Qstar).fitresult).coef[1] - -one_dimensional_path(model::ContinuousFluctuation) = LinearRegressor(fit_intercept=false, offsetcol = :offset) -one_dimensional_path(model::BinaryFluctuation) = LinearBinaryClassifier(fit_intercept=false, offsetcol = :offset) +one_dimensional_path(target_scitype::Type{T}) where T <: AbstractVector{<:MLJBase.Continuous} = LinearRegressor(fit_intercept=false, offsetcol = :offset) +one_dimensional_path(target_scitype::Type{T}) where T <: AbstractVector{<:Finite} = LinearBinaryClassifier(fit_intercept=false, offsetcol = :offset) fluctuation_input(covariate::AbstractVector{T}, offset::AbstractVector{T}) where T = (covariate=covariate, offset=offset) @@ -72,14 +45,14 @@ The GLM models require inputs of the same type fluctuation_input(covariate::AbstractVector{T1}, offset::AbstractVector{T2}) where {T1, T2} = (covariate=covariate, offset=convert(Vector{T1}, offset)) -training_expected_value(Q::Machine{<:FluctuationModel, }) = Q.cache.training_expected_value +training_expected_value(Q::Machine{<:Fluctuation, }, dataset) = Q.cache.training_expected_value -function clever_covariate_offset_and_weights(model::FluctuationModel, X) - Ψ = model.Ψ - Q⁰, G⁰ = splitQG(model.initial_factors, Ψ) +function clever_covariate_offset_and_weights(model::Fluctuation, X) + Q⁰ = model.initial_factors.outcome_mean + G⁰ = model.initial_factors.propensity_score offset = compute_offset(MLJBase.predict(Q⁰, X)) covariate, weights = clever_covariate_and_weights( - Ψ, G⁰, X; + model.Ψ, G⁰, X; ps_lowerbound=model.ps_lowerbound, weighted_fluctuation=model.weighted ) @@ -87,11 +60,11 @@ function clever_covariate_offset_and_weights(model::FluctuationModel, X) return Xfluct, weights end -function MLJBase.fit(model::FluctuationModel, verbosity, X, y) +function MLJBase.fit(model::Fluctuation, verbosity, X, y) clever_covariate_and_offset, weights = clever_covariate_offset_and_weights(model, X) mach = machine( - one_dimensional_path(model), + one_dimensional_path(scitype(y)), clever_covariate_and_offset, y, weights, @@ -108,7 +81,7 @@ function MLJBase.fit(model::FluctuationModel, verbosity, X, y) return fitresult, cache, nothing end -function MLJBase.predict(model::FluctuationModel, fitresult, X) +function MLJBase.predict(model::Fluctuation, fitresult, X) clever_covariate_and_offset, weights = clever_covariate_offset_and_weights(model, X) return MLJBase.predict(fitresult.one_dimensional_path, clever_covariate_and_offset) diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index a956c1cd..88bf5766 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -8,8 +8,8 @@ For the ATE with binary treatment, confounded by W it is equal to: ``ctf_agg = Q(1, W) - Q(0, W)`` """ -function counterfactual_aggregate(Ψ::CMCompositeEstimand, Q::ConditionalDistribution) - X = Q.machine.data[1] +function counterfactual_aggregate(Ψ::CMCompositeEstimand, Q, dataset) + X = selectcols(dataset, featurenames(Q)) Ttemplate = selectcols(X, treatments(Ψ)) n = nrows(Ttemplate) ctf_agg = zeros(n) @@ -38,20 +38,20 @@ end This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. """ -function ∇YX(Ψ::CMCompositeEstimand, Q, G; ps_lowerbound=1e-8) +function ∇YX(Ψ::CMCompositeEstimand, Q, G, dataset; ps_lowerbound=1e-8) Qmach = Q.machine - X, y = Qmach.data - H = weighted_covariate(Qmach, G, Ψ, X; ps_lowerbound=ps_lowerbound) - y = float(y) - gradient_Y_X_fluct = H .* (y .- training_expected_value(Qmach)) + H = weighted_covariate(Qmach, G, Ψ, dataset; ps_lowerbound=ps_lowerbound) + y = float(Tables.getcolumn(dataset, Q.estimand.outcome)) + gradient_Y_X_fluct = H .* (y .- training_expected_value(Qmach, dataset)) return gradient_Y_X_fluct end -function gradient_and_estimate(Ψ::CMCompositeEstimand, factors; ps_lowerbound=1e-8) - Q, G = TMLE.splitQG(factors, Ψ) - ctf_agg = counterfactual_aggregate(Ψ, Q) +function gradient_and_estimate(Ψ::CMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8) + Q = factors.outcome_mean + G = factors.propensity_score + ctf_agg = counterfactual_aggregate(Ψ, Q, dataset) Ψ̂ = mean(ctf_agg) - IC = ∇YX(Ψ, Q, G; ps_lowerbound = ps_lowerbound) .+ ∇W(ctf_agg, Ψ̂) + IC = ∇YX(Ψ, Q, G, dataset; ps_lowerbound = ps_lowerbound) .+ ∇W(ctf_agg, Ψ̂) return IC, Ψ̂ end \ No newline at end of file diff --git a/src/counterfactual_mean_based/offset_and_covariate.jl b/src/counterfactual_mean_based/offset_and_covariate.jl index 24271889..67dfd750 100644 --- a/src/counterfactual_mean_based/offset_and_covariate.jl +++ b/src/counterfactual_mean_based/offset_and_covariate.jl @@ -74,7 +74,7 @@ end If Q is a fluctuation and caches data, the covariate has already been computed and can be retrieved. """ -weighted_covariate(Q::Machine{<:FluctuationModel, }, args...; kwargs...) = Q.cache.weighted_covariate +weighted_covariate(Q::Machine{<:Fluctuation, }, args...; kwargs...) = Q.cache.weighted_covariate """ diff --git a/src/distribution_factors.jl b/src/distribution_factors.jl deleted file mode 100644 index d4a8f345..00000000 --- a/src/distribution_factors.jl +++ /dev/null @@ -1,109 +0,0 @@ -abstract type DistributionFactor end - -function key(f::DistributionFactor) end - -mutable struct ConditionalDistribution{T <: MLJBase.Supervised} <: DistributionFactor - outcome::Symbol - parents::Set{Symbol} - model::T - machine::MLJBase.Machine - function ConditionalDistribution{T}(outcome, parents, model::T) where T <: MLJBase.Supervised - outcome = Symbol(outcome) - parents = Set(Symbol(x) for x in parents) - outcome ∉ parents || throw(SelfReferringEquationError(outcome)) - return new(outcome, parents, model) - end -end - -ConditionalDistribution(outcome, parents, model::T) where T <: MLJBase.Supervised = - ConditionalDistribution{T}(outcome, parents, model) - -ConditionalDistribution(;outcome, parents, model) = - ConditionalDistribution(outcome, parents, model) - -key(cd::ConditionalDistribution) = (cd.outcome, cd.parents) - -cond_dist_string(outcome, parents) = string(outcome, " | ", join(parents, ", ")) - -string_repr(cd::ConditionalDistribution) = cond_dist_string(cd.outcome, cd.parents) - -fit_message(cd::ConditionalDistribution) = string("Fitting Conditional Distribution Factor: ", string_repr(cd)) - -update_message(cd::ConditionalDistribution) = string("Reusing or Updating Conditional Distribution Factor: ", string_repr(cd)) - -Base.show(io::IO, ::MIME"text/plain", cd::ConditionalDistribution) = println(io, string_repr(cd)) - -get_featurenames(cd::ConditionalDistribution) = sort(collect(cd.parents)) - -function MLJBase.fit!(cd::ConditionalDistribution, dataset; cache=true, verbosity=1, force=false) - # Never fitted or model has changed - dofit = !isdefined(cd, :machine) || cd.model != cd.machine.model - if dofit - verbosity >= 1 && @info(fit_message(cd)) - featurenames = get_featurenames(cd) - data = nomissing(dataset, vcat(featurenames, cd.outcome)) - X = selectcols(data, featurenames) - y = Tables.getcolumn(data, cd.outcome) - mach = machine(cd.model, X, y, cache=cache) - MLJBase.fit!(mach, verbosity=verbosity-1) - cd.machine = mach - # Also refit if force is true - else - verbosity >= 1 && @info(update_message(cd)) - MLJBase.fit!(cd.machine, verbosity=verbosity-1, force=force) - end -end - -MLJBase.fitted_params(cd::ConditionalDistribution) = fitted_params(cd.machine) - -function MLJBase.predict(cd::ConditionalDistribution, dataset) - featurenames = get_featurenames(cd) - X = selectcols(dataset, featurenames) - return predict(cd.machine, X) -end - -function expected_value(factor::ConditionalDistribution, dataset) - return expected_value(predict(factor, dataset)) -end - -function likelihood(factor::ConditionalDistribution, dataset) - ŷ = predict(factor, dataset) - y = Tables.getcolumn(dataset, factor.outcome) - return pdf.(ŷ, y) -end - -## CVCounterPart - -mutable struct CVConditionalDistribution - outcome::Symbol - parents::Set{Symbol} - model::MLJBase.Supervised - resampling::ResamplingStrategy - machines::Vector{MLJBase.Machine} -end - -function MLJBase.fit!(factor::CVConditionalDistribution, dataset; cache=true, verbosity=1, force=false) - verbosity >= 1 && @info(fit_message(factor)) - featurenames = get_featurenames(cd) - data = nomissing(dataset, vcat(featurenames, cd.outcome)) - X = selectcols(data, featurenames) - y = Tables.getcolumn(data, cd.outcome) - machines = Machine[] - for (train, val) in MLJBase.train_test_pairs(resampling, nrows(data), data) - Xtrain = selectrows(X, train) - ytrain = selectrows(y, train) - mach = machine(factor.model, Xtrain, ytrain) - fit!(mach, verbosity=verbosity-1) - push!(factor.machines, mach) - end - factor.machines = machines -end - -function MLJBase.predict(factor::CVConditionalDistribution, dataset) - ŷ = Vector{Any}(undef, size(dataset, 1)) - for (fold, (_, val)) in enumerate(MLJBase.train_test_pairs(resampling, nrows(data), data)) - Xval = selectrows(X, val) - ŷ[val] = predict(factor.machines[fold], Xval) - end - return ŷ -end diff --git a/src/estimands.jl b/src/estimands.jl index eedd3627..05165270 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -6,6 +6,11 @@ A Estimand is a functional on distribution space Ψ: ℳ → ℜ. """ abstract type Estimand end +string_repr(estimand::Estimand) = estimand + +Base.show(io::IO, ::MIME"text/plain", estimand::Estimand) = + println(io, string_repr(estimand)) + treatments(Ψ::Estimand) = collect(keys(Ψ.treatment)) AbsentLevelError(treatment_name, key, val, levels) = ArgumentError(string( @@ -63,11 +68,6 @@ Function used to sort estimands for optimal estimation ordering. """ function estimand_key end -""" -Retrieves the relevant factors of the distribution to be fitted. -""" -get_relevant_factors(Ψ::Estimand; adjustment_method::AdjustmentMethod) - """ optimize_ordering!(estimands::Vector{<:Estimand}) @@ -81,4 +81,66 @@ optimize_ordering!(estimands::Vector{<:Estimand}) = sort!(estimands, by=estimand See [`optimize_ordering!`](@ref) """ -optimize_ordering(estimands::Vector{<:Estimand}) = sort(estimands, by=estimand_key) \ No newline at end of file +optimize_ordering(estimands::Vector{<:Estimand}) = sort(estimands, by=estimand_key) + +##################################################################### +### Conditional Distribution ### +##################################################################### +""" +Defines a Conditional Distribution estimand ``(outcome, parents) → P(outcome|parents)``. +""" +struct ConditionalDistribution <: Estimand + scm::SCM + outcome::Symbol + parents::Set{Symbol} + function ConditionalDistribution(scm, outcome, parents) + outcome = Symbol(outcome) + parents = Set(Symbol(x) for x in parents) + outcome ∉ parents || throw(SelfReferringEquationError(outcome)) + # Maybe check variables are in the SCM? + return new(scm, outcome, parents) + end +end + +string_repr(estimand::ConditionalDistribution) = + string("P₀(", estimand.outcome, " | ", join(estimand.parents, ", "), ")") + +featurenames(estimand::ConditionalDistribution) = sort(collect(estimand.parents)) + +variables(estimand::ConditionalDistribution) = union(Set([estimand.outcome]), estimand.parents) + +estimand_key(cd::ConditionalDistribution) = (cd.outcome, cd.parents) + +##################################################################### +### ExpectedValue ### +##################################################################### + +""" +Defines an Expected Value estimand ``parents → E[Outcome|parents]``. +At the moment there is no distinction between an Expected Value and +a Conditional Distribution because they are estimated in the same way. +""" +const ExpectedValue = ConditionalDistribution + +##################################################################### +### CMRelevantFactors ### +##################################################################### + +""" +Defines relevant factors that need to be estimated in order to estimate any +Counterfactual Mean composite estimand (see `CMCompositeEstimand`). +""" +struct CMRelevantFactors <: Estimand + scm::SCM + outcome_mean::ConditionalDistribution + propensity_score::Tuple{Vararg{ConditionalDistribution}} +end + +string_repr(estimand::CMRelevantFactors) = + string("Composite Factor: \n", + "----------------\n- ", + string_repr(estimand.outcome_mean),"\n- ", + join((string_repr(f) for f in estimand.propensity_score), "\n- ")) + +variables(estimand::CMRelevantFactors) = + union(variables(estimand.outcome_mean), (variables(est) for est in estimand.propensity_score)...) diff --git a/src/estimate.jl b/src/estimate.jl deleted file mode 100644 index 7976be0d..00000000 --- a/src/estimate.jl +++ /dev/null @@ -1,163 +0,0 @@ - -abstract type EICEstimate end - -struct TMLEstimate{T<:AbstractFloat} <: EICEstimate - Ψ̂::T - IC::Vector{T} -end - -struct OSEstimate{T<:AbstractFloat} <: EICEstimate - Ψ̂::T - IC::Vector{T} -end - -struct ComposedEstimate{T<:AbstractFloat} - Ψ̂::Array{T} - σ̂::Matrix{T} - n::Int -end - -function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) - if length(est.σ̂) !== 1 - println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) - else - testresult = OneSampleTTest(est) - data = [estimate(est) confint(testresult) pvalue(testresult);] - headers = ["Estimate", "95% Confidence Interval", "P-value"] - pretty_table(io, data;header=headers) - end -end - -function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) - testresult = OneSampleTTest(est) - data = [estimate(est) confint(testresult) pvalue(testresult);] - pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"]) -end - -""" - Distributions.estimate(r::EICEstimate) - -Retrieves the final estimate: after the TMLE step. -""" -Distributions.estimate(est::EICEstimate) = est.Ψ̂ - -""" - Distributions.estimate(r::ComposedEstimate) - -Retrieves the final estimate: after the TMLE step. -""" -Distributions.estimate(est::ComposedEstimate) = - length(est.Ψ̂) == 1 ? est.Ψ̂[1] : est.Ψ̂ - -""" - var(r::EICEstimate) - -Computes the estimated variance associated with the estimate. -""" -Statistics.var(est::EICEstimate) = - var(est.IC)/size(est.IC, 1) - -""" - var(r::ComposedEstimate) - -Computes the estimated variance associated with the estimate. -""" -Statistics.var(est::ComposedEstimate) = - length(est.σ̂) == 1 ? est.σ̂[1] / est.n : est.σ̂ ./ est.n - - -""" - OneSampleZTest(r::EICEstimate, Ψ₀=0) - -Performs a Z test on the EICEstimate. -""" -HypothesisTests.OneSampleZTest(est::EICEstimate, Ψ₀=0) = - OneSampleZTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) - -""" - OneSampleTTest(r::EICEstimate, Ψ₀=0) - -Performs a T test on the EICEstimate. -""" -HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) = - OneSampleTTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) - -""" - OneSampleTTest(r::ComposedEstimate, Ψ₀=0) - -Performs a T test on the ComposedEstimate. -""" -function HypothesisTests.OneSampleTTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleTTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) -end - -""" - OneSampleZTest(r::ComposedEstimate, Ψ₀=0) - -Performs a T test on the ComposedEstimate. -""" -function HypothesisTests.OneSampleZTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleZTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) -end - -""" - compose(f, estimation_results::Vararg{EICEstimate, N}) where N - -Provides an estimator of f(estimation_results...). - -# Mathematical details - -The following is a summary from `Asymptotic Statistics`, A. W. van der Vaart. - -Consider k TMLEs computed from a dataset of size n and embodied by Tₙ = (T₁,ₙ, ..., Tₖ,ₙ). -Since each of them is asymptotically normal, the multivariate CLT provides the joint -distribution: - - √n(Tₙ - Ψ₀) ↝ N(0, Σ), - -where Σ is the covariance matrix of the TMLEs influence curves. - -Let f:ℜᵏ→ℜᵐ, be a differentiable map at Ψ₀. Then, the delta method provides the -limiting distribution of √n(f(Tₙ) - f(Ψ₀)). Because Tₙ is normal, the result is: - - √n(f(Tₙ) - f(Ψ₀)) ↝ N(0, ∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ), - -where ∇f(Ψ₀):ℜᵏ→ℜᵐ is a linear map such that by abusing notations and identifying the -function with the multiplication matrix: ∇f(Ψ₀):h ↦ ∇f(Ψ₀) ̇h. And the matrix ∇f(Ψ₀) is -the jacobian of f at Ψ₀. - -Hence, the only thing we need to do is: -- Compute the covariance matrix Σ -- Compute the jacobian ∇f, which can be done using Julia's automatic differentiation facilities. -- The final estimator is normal with mean f₀=f(Ψ₀) and variance σ₀=∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ - -# Arguments - -- f: An array-input differentiable map. -- estimation_results: 1 or more `EICEstimate` structs. - -# Examples - -Assuming `res₁` and `res₂` are TMLEs: - -```julia -f(x, y) = [x^2 - y, y - 3x] -compose(f, res₁, res₂) -``` -""" -function compose(f, estimators::Vararg{EICEstimate, N}; backend=AD.ZygoteBackend()) where N - Σ = cov(estimators...) - estimates = [estimate(r) for r in estimators] - f₀, Js = AD.value_and_jacobian(backend, f, estimates...) - J = hcat(Js...) - n = size(first(estimators).IC, 1) - σ₀ = J*Σ*J' - return ComposedEstimate(collect(f₀), σ₀, n) -end - -function Statistics.cov(estimators::Vararg{EICEstimate, N}) where N - X = hcat([r.IC for r in estimators]...) - return Statistics.cov(X, dims=1, corrected=true) -end diff --git a/src/estimates.jl b/src/estimates.jl new file mode 100644 index 00000000..aa631284 --- /dev/null +++ b/src/estimates.jl @@ -0,0 +1,342 @@ + + +abstract type Estimate end + +string_repr(estimate::Estimate) = estimate + +Base.show(io::IO, ::MIME"text/plain", estimate::Estimate) = + println(io, string_repr(estimate)) + +##################################################################### +### MLConditionalDistribution ### +##################################################################### + +""" +Holds a Machine Learning estimate for a Conditional Distribution. +""" +struct MLConditionalDistribution <: Estimate + estimand::ConditionalDistribution + machine::MLJBase.Machine +end + +string_repr(estimate::MLConditionalDistribution) = string( + "P̂(", estimate.estimand.outcome, " | ", join(estimate.estimand.parents, ", "), + "), with model: ", + Base.typename(typeof(estimate.machine.model)).wrapper +) + +featurenames(estimate::MLConditionalDistribution) = featurenames(estimate.estimand) + +function MLJBase.predict(estimate::MLConditionalDistribution, dataset) + X = selectcols(dataset, featurenames(estimate)) + return predict(estimate.machine, X) +end + +function Distributions.estimate(factor::ConditionalDistribution, dataset, model, train_validation_indices::Nothing; + factors_cache=nothing, + verbosity=1) + relevant_dataset = TMLE.nomissing(dataset, TMLE.variables(factor)) + X = selectcols(relevant_dataset, TMLE.featurenames(factor)) + y = Tables.getcolumn(relevant_dataset, factor.outcome) + mach = machine(model, X, y) + fit!(mach, verbosity=verbosity-1) + return MLConditionalDistribution(factor, mach) +end + +##################################################################### +### SampleSplitMLConditionalDistribution ### +##################################################################### + +""" +Holds a Sample Split Machine Learning estimate for a Conditional Distribution. +""" +struct SampleSplitMLConditionalDistribution <: Estimate + estimand::ConditionalDistribution + train_validation_indices + machines +end + +string_repr(estimate::SampleSplitMLConditionalDistribution) = + string("P̂(", estimate.estimand.outcome, " | ", join(estimate.estimand.parents, ", "), + "), sample split with model: ", + Base.typename(typeof(first(estimate.machines).model)).wrapper +) + +function MLJBase.predict(estimate::SampleSplitMLConditionalDistribution, dataset) + X = selectcols(dataset, featurenames(estimate.estimand)) + ŷs = [] + for (fold, (_, validation_indices)) in enumerate(estimate.train_validation_indices) + Xval = selectrows(X, validation_indices) + push!(ŷs, predict(estimate.machines[fold], Xval)) + end + return vcat(ŷs...) +end + +function Distributions.estimate(factor::ConditionalDistribution, dataset, model, train_validation_indices; + factors_cache=nothing, + verbosity=1) + relevant_dataset = TMLE.selectcols(dataset, TMLE.variables(factor)) + nfolds = size(train_validation_indices, 1) + machines = Vector{Machine}(undef, nfolds) + features = TMLE.featurenames(factor) + for (index, (train_indices, _)) in enumerate(train_validation_indices) + train_dataset = selectrows(relevant_dataset, train_indices) + Xtrain = selectcols(train_dataset, features) + ytrain = Tables.getcolumn(train_dataset, factor.outcome) + mach = machine(model, Xtrain, ytrain) + fit!(mach, verbosity=verbosity-1) + machines[index] = mach + end + return SampleSplitMLConditionalDistribution(factor, train_validation_indices, machines) +end + +##################################################################### +### ConditionalDistributionEstimate ### +##################################################################### + +ConditionalDistributionEstimate = Union{MLConditionalDistribution, SampleSplitMLConditionalDistribution} + +function expected_value(estimate::ConditionalDistributionEstimate, dataset) + return expected_value(predict(estimate, dataset)) +end + +function likelihood(estimate::ConditionalDistributionEstimate, dataset) + ŷ = predict(estimate, dataset) + y = Tables.getcolumn(dataset, estimate.estimand.outcome) + return pdf.(ŷ, y) +end + + +##################################################################### +### MLCMRelevantFactors ### +##################################################################### + +""" +Holds a Sample Split Machine Learning set of estimates (outcome mean, propensity score) +for counterfactual mean based estimands' relevant factors. +""" +struct MLCMRelevantFactors <: Estimate + estimand::CMRelevantFactors + outcome_mean::ConditionalDistributionEstimate + propensity_score::Tuple{Vararg{ConditionalDistributionEstimate}} +end + +get_train_validation_indices(resampling::Nothing, factors, dataset) = nothing + +function get_train_validation_indices(resampling::ResamplingStrategy, factors, dataset) + relevant_columns = collect(TMLE.variables(factors)) + outcome_variable = factors.outcome_mean.outcome + feature_variables = filter(x -> x !== outcome_variable, relevant_columns) + return collect(MLJBase.train_test_pairs( + resampling, + 1:nrows(dataset), + selectcols(dataset, feature_variables), + Tables.getcolumn(dataset, outcome_variable) + )) +end + +function Distributions.estimate(relevant_factors::CMRelevantFactors, resampling, models, dataset; + factors_cache=nothing, + verbosity=1) + # Get train validation indices + train_validation_indices = get_train_validation_indices(resampling, relevant_factors, dataset) + # Fit propensity score + propensity_score_estimate = Tuple( + estimate( + factor, + dataset, + models[factor.outcome], + train_validation_indices; + verbosity=verbosity, + factors_cache=factors_cache) + for factor in relevant_factors.propensity_score + ) + # Fit outcome mean + outcome_mean = relevant_factors.outcome_mean + model = models[outcome_mean.outcome] + outcome_mean_estimate = TMLE.estimate( + outcome_mean, + dataset, + model, + train_validation_indices; + verbosity=verbosity, + factors_cache=factors_cache) + + return MLCMRelevantFactors(relevant_factors, outcome_mean_estimate, propensity_score_estimate) +end + +string_repr(estimate::MLCMRelevantFactors) = string( + "Composite Factor Estimate: \n", + "-------------------------\n- ", + string_repr(estimate.outcome_mean),"\n- ", + join((string_repr(f) for f in estimate.propensity_score), "\n- ") +) + + +##################################################################### +### One Dimensional Estimates ### +##################################################################### + + +struct TMLEstimate{T<:AbstractFloat} <: Estimate + Ψ̂::T + IC::Vector{T} +end + +struct OSEstimate{T<:AbstractFloat} <: Estimate + Ψ̂::T + IC::Vector{T} +end + +const EICEstimate = Union{TMLEstimate, OSEstimate} + +function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) + testresult = OneSampleTTest(est) + data = [estimate(est) confint(testresult) pvalue(testresult);] + pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"]) +end + +struct ComposedEstimate{T<:AbstractFloat} <: Estimate + Ψ̂::Array{T} + σ̂::Matrix{T} + n::Int +end + +function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) + if length(est.σ̂) !== 1 + println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) + else + testresult = OneSampleTTest(est) + data = [estimate(est) confint(testresult) pvalue(testresult);] + headers = ["Estimate", "95% Confidence Interval", "P-value"] + pretty_table(io, data;header=headers) + end +end + +""" + Distributions.estimate(r::EICEstimate) + +Retrieves the final estimate: after the TMLE step. +""" +Distributions.estimate(est::EICEstimate) = est.Ψ̂ + +""" + Distributions.estimate(r::ComposedEstimate) + +Retrieves the final estimate: after the TMLE step. +""" +Distributions.estimate(est::ComposedEstimate) = + length(est.Ψ̂) == 1 ? est.Ψ̂[1] : est.Ψ̂ + +""" + var(r::EICEstimate) + +Computes the estimated variance associated with the estimate. +""" +Statistics.var(est::EICEstimate) = + var(est.IC)/size(est.IC, 1) + +""" + var(r::ComposedEstimate) + +Computes the estimated variance associated with the estimate. +""" +Statistics.var(est::ComposedEstimate) = + length(est.σ̂) == 1 ? est.σ̂[1] / est.n : est.σ̂ ./ est.n + + +""" + OneSampleZTest(r::EICEstimate, Ψ₀=0) + +Performs a Z test on the EICEstimate. +""" +HypothesisTests.OneSampleZTest(est::EICEstimate, Ψ₀=0) = + OneSampleZTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) + +""" + OneSampleTTest(r::EICEstimate, Ψ₀=0) + +Performs a T test on the EICEstimate. +""" +HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) = + OneSampleTTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) + +""" + OneSampleTTest(r::ComposedEstimate, Ψ₀=0) + +Performs a T test on the ComposedEstimate. +""" +function HypothesisTests.OneSampleTTest(est::ComposedEstimate, Ψ₀=0) + @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleTTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) +end + +""" + OneSampleZTest(r::ComposedEstimate, Ψ₀=0) + +Performs a T test on the ComposedEstimate. +""" +function HypothesisTests.OneSampleZTest(est::ComposedEstimate, Ψ₀=0) + @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleZTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) +end + +""" + compose(f, estimation_results::Vararg{EICEstimate, N}) where N + +Provides an estimator of f(estimation_results...). + +# Mathematical details + +The following is a summary from `Asymptotic Statistics`, A. W. van der Vaart. + +Consider k TMLEs computed from a dataset of size n and embodied by Tₙ = (T₁,ₙ, ..., Tₖ,ₙ). +Since each of them is asymptotically normal, the multivariate CLT provides the joint +distribution: + + √n(Tₙ - Ψ₀) ↝ N(0, Σ), + +where Σ is the covariance matrix of the TMLEs influence curves. + +Let f:ℜᵏ→ℜᵐ, be a differentiable map at Ψ₀. Then, the delta method provides the +limiting distribution of √n(f(Tₙ) - f(Ψ₀)). Because Tₙ is normal, the result is: + + √n(f(Tₙ) - f(Ψ₀)) ↝ N(0, ∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ), + +where ∇f(Ψ₀):ℜᵏ→ℜᵐ is a linear map such that by abusing notations and identifying the +function with the multiplication matrix: ∇f(Ψ₀):h ↦ ∇f(Ψ₀) ̇h. And the matrix ∇f(Ψ₀) is +the jacobian of f at Ψ₀. + +Hence, the only thing we need to do is: +- Compute the covariance matrix Σ +- Compute the jacobian ∇f, which can be done using Julia's automatic differentiation facilities. +- The final estimator is normal with mean f₀=f(Ψ₀) and variance σ₀=∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ + +# Arguments + +- f: An array-input differentiable map. +- estimation_results: 1 or more `EICEstimate` structs. + +# Examples + +Assuming `res₁` and `res₂` are TMLEs: + +```julia +f(x, y) = [x^2 - y, y - 3x] +compose(f, res₁, res₂) +``` +""" +function compose(f, estimators::Vararg{EICEstimate, N}; backend=AD.ZygoteBackend()) where N + Σ = cov(estimators...) + estimates = [estimate(r) for r in estimators] + f₀, Js = AD.value_and_jacobian(backend, f, estimates...) + J = hcat(Js...) + n = size(first(estimators).IC, 1) + σ₀ = J*Σ*J' + return ComposedEstimate(collect(f₀), σ₀, n) +end + +function Statistics.cov(estimators::Vararg{EICEstimate, N}) where N + X = hcat([r.IC for r in estimators]...) + return Statistics.cov(X, dims=1, corrected=true) +end diff --git a/src/scm.jl b/src/scm.jl index 9052b1ae..14991ca5 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -101,32 +101,44 @@ UsingNaturalDistributionLog(scm, outcome, conditioning_set) = string( cond_dist_string(outcome, parents(scm, outcome)) ) -function get_or_set_conditional_distribution_from_natural!(scm::SCM, outcome::Symbol, conditioning_set::Set{Symbol}; verbosity=1) +function get_or_set_conditional_distribution_from_natural!(scm::SCM, outcome::Symbol, conditioning_set::Set{Symbol}; resampling=nothing, verbosity=1) try - return get_conditional_distribution(scm, outcome, conditioning_set) + # A factor has already been setup + factor = get_conditional_distribution(scm, outcome, conditioning_set) + # If the resampling strategies also match then we can reuse the already set factor + new_factor = ConditionalDistribution(outcome, conditioning_set, factor.model, resampling=resampling) + if match(factor, new_factor) + return factor + else + setfactor!(scm, new_factor) + return new_factor + end catch KeyError + # A factor has not already been setup try + # A natural factor can be used as a template template_factor = get_conditional_distribution(scm, outcome) - set_conditional_distribution!(scm, outcome, conditioning_set, template_factor.model) - factor = get_conditional_distribution(scm, outcome, conditioning_set) + new_factor = set_conditional_distribution!(scm, outcome, conditioning_set, template_factor.model, resampling) verbosity > 0 && @info(UsingNaturalDistributionLog(scm, outcome, conditioning_set)) - return factor + return new_factor catch KeyError throw(NoAvailableConditionalDistributionError(scm, outcome, conditioning_set)) end end end -set_conditional_distribution!(scm::SCM, outcome::Symbol, model) = - set_conditional_distribution!(scm, outcome, parents(scm, outcome), model) +set_conditional_distribution!(scm::SCM, outcome::Symbol, model, resampling) = + set_conditional_distribution!(scm, outcome, parents(scm, outcome), model, resampling) -function set_conditional_distribution!(scm::SCM, outcome::Symbol, parents, model) - factor = ConditionalDistribution(outcome, parents, model) - setfactor!(scm, factor) +function set_conditional_distribution!(scm::SCM, outcome::Symbol, parents, model, resampling) + factor = ConditionalDistribution(outcome, parents, model, resampling=resampling) + return setfactor!(scm, factor) end -setfactor!(scm::SCM, factor::DistributionFactor) = +function setfactor!(scm::SCM, factor) scm.factors[key(factor)] = factor + return factor +end function string_repr(scm::SCM) scm_string = """ @@ -192,11 +204,11 @@ function StaticConfoundedModel( Teq = SE(treatment, vcat(confounders)) outcome_factor = ConditionalDistribution(Yeq.outcome, Yeq.parents, outcome_model) treatment_factor = ConditionalDistribution(Teq.outcome, Teq.parents, treatment_model) - factors = Dict( - key(outcome_factor) => outcome_factor, - key(treatment_factor) => treatment_factor - ) - return StructuralCausalModel(Yeq, Teq; factors) + # factors = Dict( + # key(outcome_factor) => outcome_factor, + # key(treatment_factor) => treatment_factor + # ) + return StructuralCausalModel(Yeq, Teq) end """ @@ -228,18 +240,18 @@ function StaticConfoundedModel( ) Yequations = (SE(outcome, combine_outcome_parents(treatments, confounders, covariates)) for outcome in outcomes) Tequations = (SE(treatment, vcat(confounders)) for treatment in treatments) - factors = Dict( - (outcome(Yeq) => Dict(parents(Yeq) => outcome_model) for Yeq in Yequations)..., - (outcome(Teq) => Dict(parents(Teq) => treatment_model) for Teq in Tequations)... - ) - factors = Dict() - for eq in Yequations - factor = ConditionalDistribution(eq.outcome, eq.parents, outcome_model) - factors[key(factor)] = factor - end - for eq in Tequations - factor = ConditionalDistribution(eq.outcome, eq.parents, treatment_model) - factors[key(factor)] = factor - end - return SCM(Yequations..., Tequations...; factors=factors) + # factors = Dict( + # (outcome(Yeq) => Dict(parents(Yeq) => outcome_model) for Yeq in Yequations)..., + # (outcome(Teq) => Dict(parents(Teq) => treatment_model) for Teq in Tequations)... + # ) + # factors = Dict() + # for eq in Yequations + # factor = ConditionalDistribution(eq.outcome, eq.parents, outcome_model) + # factors[key(factor)] = factor + # end + # for eq in Tequations + # factor = ConditionalDistribution(eq.outcome, eq.parents, treatment_model) + # factors[key(factor)] = factor + # end + return SCM(Yequations..., Tequations...) end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 1457309f..ee95122f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -41,7 +41,7 @@ expected_value(ŷ::UnivariateFiniteVector{Multiclass{2}}) = pdf.(ŷ, levels(firs expected_value(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = mean.(ŷ) expected_value(ŷ::AbstractVector{<:Real}) = ŷ -training_expected_value(Q::Machine) = expected_value(predict(Q)) +training_expected_value(Q::Machine, dataset) = expected_value(predict(Q, dataset)) function counterfactualTreatment(vals, T) diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index bcc535ba..e1998903 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -6,6 +6,7 @@ using CSV using DataFrames using CategoricalArrays using MLJGLMInterface +using MLJBase @testset "Test ATE on perinatal dataset." begin # This is a non-regression test which was checked against the R tmle3 package @@ -24,16 +25,41 @@ using MLJGLMInterface outcome_model = TreatmentTransformer() |> LinearBinaryClassifier(), treatment_model = LinearBinaryClassifier() ) + models = ( + haz01 = with_encoder(LinearBinaryClassifier()), + parity01 = LinearBinaryClassifier() + ) - tmle_result, targeted_factors = tmle!(Ψ, dataset; ps_lowerbound=0.025, verbosity=0) # TMLE - @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 - @test naive_plugin_estimate!(Ψ, dataset) ≈ -0.150078 atol = 1e-6 + ps_lowerbound = 0.025 # Cutoff hardcoded in tmle3 + resampling = nothing # Vanilla TMLE + weighted_fluctuation = false # Unweighted fluctuation + verbosity = 1 # No logs + adjustment_method = BackdoorAdjustment() + factors_cache = nothing + tmle_result, targeted_factors = tmle!(Ψ, models, dataset; + resampling=resampling, + weighted_fluctuation=weighted_fluctuation, + ps_lowerbound=ps_lowerbound, + adjustment_method=adjustment_method, + verbosity=verbosity) + tmle_result + estimate(tmle_result) ≈ -0.185533 atol = 1e-6 l, u = confint(OneSampleTTest(tmle_result)) @test l ≈ -0.279246 atol = 1e-6 @test u ≈ -0.091821 atol = 1e-6 @test OneSampleZTest(tmle_result) isa OneSampleZTest + # CV-TMLE + + tmle_result, targeted_factors = tmle!(Ψ, dataset; + resampling=resampling, + adjustment_method=adjustment_method, + ps_lowerbound=ps_lowerbound, + verbosity=verbosity) + # Naive + @test naive_plugin_estimate!(Ψ, dataset) ≈ -0.150078 atol = 1e-6 + end end From f3cf47d7b493d30d37de36d9ab1f01738fad9183 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 6 Sep 2023 14:59:09 +0100 Subject: [PATCH 065/151] add more content --- sandbox.R | 134 +++++------------- src/counterfactual_mean_based/estimands.jl | 9 +- src/counterfactual_mean_based/estimators.jl | 56 ++++---- src/counterfactual_mean_based/gradient.jl | 27 +++- .../offset_and_covariate.jl | 5 +- src/scm.jl | 44 +----- .../non_regression_test.jl | 30 ++-- 7 files changed, 112 insertions(+), 193 deletions(-) diff --git a/sandbox.R b/sandbox.R index e6917201..01387d93 100644 --- a/sandbox.R +++ b/sandbox.R @@ -1,121 +1,53 @@ library(data.table) library(tmle3) library(sl3) -data <- fread( - paste0( - "https://raw.githubusercontent.com/tlverse/tlverse-data/master/", - "wash-benefits/washb_data.csv" - ), - stringsAsFactors = TRUE -) + +data = read.csv("/Users/olivierlabayle/Dev/TARGENE/TMLE.jl/test/data/perinatal.csv") node_list <- list( W = c( - "month", "aged", "sex", "momage", "momedu", - "momheight", "hfiacat", "Nlt18", "Ncomp", "watmin", - "elec", "floor", "walls", "roof", "asset_wardrobe", - "asset_table", "asset_chair", "asset_khat", - "asset_chouki", "asset_tv", "asset_refrig", - "asset_bike", "asset_moto", "asset_sewmach", - "asset_mobile" + "apgar1", "apgar5", "gagebrth", "mage", "meducyrs", "sexn" ), - A = "tr", - Y = "whz" + A = "parity01", + Y = "haz01" ) -processed <- process_missing(data, node_list) -data <- processed$data -node_list <- processed$node_list +glm = Lrnr_glm$new() +lrn_mean = Lrnr_mean$new() +sl <- Lrnr_sl$new(learners = Stack$new(glm, lrn_mean), metalearner = Lrnr_nnls$new()) + +learner_list <- list(A = glm, Y = glm) +# learner_list = list(A=sl, Y = sl) ate_spec <- tmle_ATE( - treatment_level = "Nutrition + WSH", - control_level = "Control" + treatment_level = 1, + control_level = 0 ) -tmle_spec = ate_spec - -# choose base learners -lrnr_mean <- make_learner(Lrnr_mean) -lrnr_rf <- make_learner(Lrnr_ranger) -# define metalearners appropriate to data types -ls_metalearner <- make_learner(Lrnr_nnls) -mn_metalearner <- make_learner( - Lrnr_solnp, metalearner_linear_multinomial, - loss_loglik_multinomial -) -sl_Y <- Lrnr_sl$new( - learners = list(lrnr_mean, lrnr_rf), - metalearner = ls_metalearner +tmle_task <- ate_spec$make_tmle_task(data, node_list) +initial_likelihood <- ate_spec$make_initial_likelihood( + tmle_task, + learner_list ) -sl_A <- Lrnr_sl$new( - learners = list(lrnr_mean, lrnr_rf), - metalearner = mn_metalearner -) -learner_list <- list(A = sl_A, Y = sl_Y) - -tmle_task <- tmle_spec$make_tmle_task(data, node_list) - -initial_likelihood <- tmle_spec$make_initial_likelihood(tmle_task, learner_list) - -updater <- tmle_spec$make_updater() -targeted_likelihood <- tmle_spec$make_targeted_likelihood(initial_likelihood, updater) - -tmle_params <- tmle_spec$make_params(tmle_task, targeted_likelihood) -updater$tmle_params <- tmle_params -fit <- fit_tmle3(tmle_task, targeted_likelihood, tmle_params, updater) +targeted_likelihood_cv <- Targeted_Likelihood$new(initial_likelihood) -update_fold <- updater$update_fold -maxit <- 100 +targeted_likelihood_no_cv <- + Targeted_Likelihood$new(initial_likelihood, + updater = list(cvtmle = FALSE) + ) -# seed current estimates -current_estimates <- lapply(updater$tmle_params, function(tmle_param) { - tmle_param$estimates(tmle_task, update_fold) -}) +tmle_params_cv <- ate_spec$make_params(tmle_task, targeted_likelihood_cv) +tmle_params_no_cv <- ate_spec$make_params(tmle_task, targeted_likelihood_no_cv) -likelihood = initial_likelihood -likelihood_values <- likelihood$cache$get_values(A_factor, tmle_task, fold_number) -likelihood_values <- A_factor$get_likelihood(tmle_task, fold_number) - -update_node <- updater$update_nodes[[1]] -submodel_data <- updater$generate_submodel_data( - likelihood, tmle_task, - fold_number, update_node, - drop_censored = TRUE - ) - -for (steps in seq_len(maxit)) { - updater$update_step(likelihood, tmle_task, update_fold) - - # update estimates based on updated likelihood - private$.current_estimates <- lapply(self$tmle_params, function(tmle_param) { - tmle_param$estimates(tmle_task, update_fold) - }) - - if (self$check_convergence(tmle_task, update_fold)) { - break - } - - if (self$use_best) { - self$update_best(likelihood) - } -} - -if (self$use_best) { - self$update_best(likelihood) - likelihood$cache$set_best() -} - -private$.steps <- self$updater$steps - - -estimates <- lapply( - self$tmle_params, - function(tmle_param) { - tmle_param$estimates(self$tmle_task, self$updater$update_fold) - } +tmle_no_cv <- fit_tmle3( + tmle_task, targeted_likelihood_no_cv, tmle_params_no_cv, + targeted_likelihood_no_cv$updater ) +tmle_no_cv -private$.estimates <- estimates -private$.ED <- ED_from_estimates(estimates) - +tmle_cv <- fit_tmle3( + tmle_task, targeted_likelihood_cv, tmle_params_cv, + targeted_likelihood_cv$updater +) +tmle_cv \ No newline at end of file diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index b7097af5..b66afb62 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -165,16 +165,13 @@ for typename in CMCompositeTypenames outcome::Symbol, treatment::NamedTuple, confounders::Union{Symbol, AbstractVector{Symbol}}, - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LinearBinaryClassifier()) + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing + ) scm = StaticConfoundedModel( [outcome], collect(keys(treatment)), confounders; - covariates=covariates, - outcome_model=outcome_model, - treatment_model=treatment_model + covariates=covariates ) return $(typename)(scm; outcome=outcome, treatment=treatment) end diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 135e0504..2a181bc4 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -12,7 +12,7 @@ function tmle!(Ψ::CMCompositeEstimand, models, dataset; ) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) - # Initial fit of the SCM's equations + # Initial fit of the SCM's relevant factors verbosity >= 1 && @info "Fitting the required equations..." relevant_factors = TMLE.get_relevant_factors(Ψ; adjustment_method=adjustment_method) nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) @@ -38,46 +38,48 @@ function tmle!(Ψ::CMCompositeEstimand, models, dataset; return TMLEstimate(Ψ̂, IC), targeted_factors_estimate end -function ose!(Ψ::CMCompositeEstimand, dataset; +function ose!(Ψ::CMCompositeEstimand, models, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, - force=false, - ps_lowerbound=1e-8) + force=false, + resampling=nothing, + ps_lowerbound=1e-8, + factors_cache=nothing) # Check the estimand against the dataset - check_treatment_levels(Ψ, dataset) - # Initial fit of the SCM's equations + TMLE.check_treatment_levels(Ψ, dataset) + # Initial fit of the SCM's relevant factors verbosity >= 1 && @info "Fitting the required equations..." - fit!(Ψ, dataset; - adjustment_method=adjustment_method, - verbosity=verbosity, - force=force + relevant_factors = TMLE.get_relevant_factors(Ψ; adjustment_method=adjustment_method) + nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) + initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, resampling) + initial_factors_estimate = TMLE.estimate(relevant_factors, resampling, models, initial_factors_dataset; + factors_cache=factors_cache, + verbosity=verbosity ) # Get propensity score truncation threshold - ps_lowerbound = ps_lower_bound(Ψ, ps_lowerbound) - # Retrieve initial fit - Q = get_outcome_model(Ψ) + n = nrows(nomissing_dataset) + ps_lowerbound = TMLE.ps_lower_bound(n, ps_lowerbound) + # Gradient and estimate - IC, Ψ̂ = gradient_and_estimate(Ψ, Q; ps_lowerbound=ps_lowerbound) + IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." - return OSEstimate(Ψ̂ + mean(IC), IC), Q + return OSEstimate(Ψ̂ + mean(IC), IC), initial_factors_estimate end -function naive_plugin_estimate!(Ψ::CMCompositeEstimand, dataset; +function naive_plugin_estimate!(Ψ::CMCompositeEstimand, models, dataset; adjustment_method=BackdoorAdjustment(), verbosity=1, - force=false) + force=false, + factors_cache=nothing) # Check the estimand against the dataset - check_treatment_levels(Ψ, dataset) - # Initial fit of the SCM's equations + TMLE.check_treatment_levels(Ψ, dataset) + # Initial fit of the SCM's relevant factors verbosity >= 1 && @info "Fitting the required equations..." - Q = get_or_set_conditional_distribution_from_natural!( - Ψ.scm, outcome(Ψ), - outcome_parents(adjustment_method, Ψ); + relevant_factors = TMLE.get_relevant_factors(Ψ; adjustment_method=adjustment_method) + nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) + factors_estimate = TMLE.estimate(relevant_factors, nothing, models, dataset; + factors_cache=factors_cache, verbosity=verbosity ) - fit!(Q, dataset; - verbosity=verbosity, - force=force - ) - return mean(counterfactual_aggregate(Ψ, Q)) + return mean(counterfactual_aggregate(Ψ, factors_estimate.outcome_mean, nomissing_dataset)) end \ No newline at end of file diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index 88bf5766..9d1fbac9 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -24,6 +24,12 @@ function counterfactual_aggregate(Ψ::CMCompositeEstimand, Q, dataset) return ctf_agg end +compute_estimate(ctf_aggregate, ::Nothing) = mean(ctf_aggregate) + +compute_estimate(ctf_aggregate, train_validation_indices) = + mean(compute_estimate(ctf_aggregate[val_indices], nothing) for (_, val_indices) in train_validation_indices) + + """ ∇W(ctf_agg, Ψ̂) @@ -39,10 +45,10 @@ end This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. """ function ∇YX(Ψ::CMCompositeEstimand, Q, G, dataset; ps_lowerbound=1e-8) - Qmach = Q.machine - H = weighted_covariate(Qmach, G, Ψ, dataset; ps_lowerbound=ps_lowerbound) + # Maybe can cache some results (H and E[Y|X]) to improve perf here + H, weights = clever_covariate_and_weights(Ψ, G, dataset; ps_lowerbound=ps_lowerbound) y = float(Tables.getcolumn(dataset, Q.estimand.outcome)) - gradient_Y_X_fluct = H .* (y .- training_expected_value(Qmach, dataset)) + gradient_Y_X_fluct = H .* weights .* (y .- expected_value(Q, dataset)) return gradient_Y_X_fluct end @@ -50,8 +56,15 @@ end function gradient_and_estimate(Ψ::CMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8) Q = factors.outcome_mean G = factors.propensity_score - ctf_agg = counterfactual_aggregate(Ψ, Q, dataset) - Ψ̂ = mean(ctf_agg) - IC = ∇YX(Ψ, Q, G, dataset; ps_lowerbound = ps_lowerbound) .+ ∇W(ctf_agg, Ψ̂) + ctf_agg = TMLE.counterfactual_aggregate(Ψ, Q, dataset) + Ψ̂ = TMLE.compute_estimate(ctf_agg, TMLE.train_validation_indices_from_factors(factors)) + IC = TMLE.∇YX(Ψ, Q, G, dataset; ps_lowerbound = ps_lowerbound) .+ TMLE.∇W(ctf_agg, Ψ̂) return IC, Ψ̂ -end \ No newline at end of file +end + + +train_validation_indices_from_ps(::MLConditionalDistribution) = nothing +train_validation_indices_from_ps(factor::SampleSplitMLConditionalDistribution) = factor.train_validation_indices + +train_validation_indices_from_factors(factors) = + train_validation_indices_from_ps(first(factors.propensity_score)) \ No newline at end of file diff --git a/src/counterfactual_mean_based/offset_and_covariate.jl b/src/counterfactual_mean_based/offset_and_covariate.jl index 67dfd750..43203ed3 100644 --- a/src/counterfactual_mean_based/offset_and_covariate.jl +++ b/src/counterfactual_mean_based/offset_and_covariate.jl @@ -11,9 +11,6 @@ data_adaptive_ps_lower_bound(n::Int; max_lb=0.1) = ps_lower_bound(n::Int, lower_bound::Nothing; max_lb=0.1) = data_adaptive_ps_lower_bound(n; max_lb=max_lb) ps_lower_bound(n::Int, lower_bound; max_lb=0.1) = min(max_lb, lower_bound) -treatments(dataset, G) = selectcols(dataset, keys(G)) -confounders(dataset, G) = (;(key => selectcols(dataset, keys(cd.machine.data[1])) for (key, cd) in zip(keys(G), G))...) - function truncate!(v::AbstractVector, ps_lowerbound::AbstractFloat) for i in eachindex(v) @@ -58,7 +55,7 @@ where SpecialIndicator(t) is defined in `indicator_fns`. """ function clever_covariate_and_weights(Ψ::CMCompositeEstimand, Gs, dataset; ps_lowerbound=1e-8, weighted_fluctuation=false) # Compute the indicator values - T = treatments(dataset, Gs) + T = TMLE.selectcols(dataset, (p.estimand.outcome for p in Gs)) indic_vals = indicator_values(indicator_fns(Ψ), T) weights = balancing_weights(Gs, dataset; ps_lowerbound=ps_lowerbound) if weighted_fluctuation diff --git a/src/scm.jl b/src/scm.jl index 14991ca5..c295ae65 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -66,12 +66,13 @@ scm = SCM( struct StructuralCausalModel equations::Dict{Symbol, StructuralEquation} factors::Dict + StructuralCausalModel(equations) = new(equations, Dict()) end const SCM = StructuralCausalModel -SCM(;equations=Dict{Symbol, SE}(), factors=Dict()) = SCM(equations, factors) -SCM(equations::Vararg{SE}; factors=Dict()) = SCM(Dict(outcome(eq) => eq for eq in equations), factors) +SCM(;equations=Dict{Symbol, SE}()) = SCM(equations) +SCM(equations::Vararg{SE}) = SCM(Dict(outcome(eq) => eq for eq in equations)) equations(scm::SCM) = scm.equations factors(scm::SCM) = scm.factors @@ -177,37 +178,24 @@ end ### StaticConfoundedModel ### ##################################################################### -combine_outcome_parents(treatment, confounders, covariates::Nothing) = vcat(treatment, confounders) +combine_outcome_parents(treatment, confounders, ::Nothing) = vcat(treatment, confounders) combine_outcome_parents(treatment, confounders, covariates) = vcat(treatment, confounders, covariates) """ StaticConfoundedModel( outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}}; - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = TreatmentTransformer() |> LinearRegressor(), - treatment_model = LinearBinaryClassifier() + covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing ) Defines a classic Structural Causal Model with one outcome, one treatment, a set of confounding variables and optional covariates influencing the outcome only. - -The `outcome_model` and `treatment_model` define the relationship between -the outcome (resp. treatment) and their ancestors. """ function StaticConfoundedModel( outcome, treatment, confounders; - covariates = nothing, - outcome_model = TreatmentTransformer() |> LinearRegressor(), - treatment_model = LinearBinaryClassifier(), + covariates = nothing ) Yeq = SE(outcome, combine_outcome_parents(treatment, confounders, covariates)) Teq = SE(treatment, vcat(confounders)) - outcome_factor = ConditionalDistribution(Yeq.outcome, Yeq.parents, outcome_model) - treatment_factor = ConditionalDistribution(Teq.outcome, Teq.parents, treatment_model) - # factors = Dict( - # key(outcome_factor) => outcome_factor, - # key(treatment_factor) => treatment_factor - # ) return StructuralCausalModel(Yeq, Teq) end @@ -226,32 +214,14 @@ a set of confounding variables and optional covariates influencing the outcomes All treatments are assumed to be direct parents of all outcomes. The confounding variables are shared for all treatments. - -The `outcome_model` and `treatment_model` define the relationships between -the outcomes (resp. treatments) and their ancestors. """ function StaticConfoundedModel( outcomes::AbstractVector, treatments::AbstractVector, confounders; - covariates = nothing, - outcome_model = TreatmentTransformer() |> LinearRegressor(), - treatment_model = LinearBinaryClassifier() + covariates = nothing ) Yequations = (SE(outcome, combine_outcome_parents(treatments, confounders, covariates)) for outcome in outcomes) Tequations = (SE(treatment, vcat(confounders)) for treatment in treatments) - # factors = Dict( - # (outcome(Yeq) => Dict(parents(Yeq) => outcome_model) for Yeq in Yequations)..., - # (outcome(Teq) => Dict(parents(Teq) => treatment_model) for Teq in Tequations)... - # ) - # factors = Dict() - # for eq in Yequations - # factor = ConditionalDistribution(eq.outcome, eq.parents, outcome_model) - # factors[key(factor)] = factor - # end - # for eq in Tequations - # factor = ConditionalDistribution(eq.outcome, eq.parents, treatment_model) - # factors[key(factor)] = factor - # end return SCM(Yequations..., Tequations...) end \ No newline at end of file diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index e1998903..c6db7eda 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -13,7 +13,7 @@ using MLJBase dataset = CSV.read(joinpath("data", "perinatal.csv"), DataFrame, missingstring=["", "NA"]) confounders = [:apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn] dataset.haz01 = categorical(dataset.haz01) - dataset.parity01 = categorical(dataset.parity01) + dataset.parity01 = categorical(dataset.parity01, ordered=true) for col in confounders dataset[!, col] = float(dataset[!, col]) end @@ -21,16 +21,13 @@ using MLJBase Ψ = ATE( outcome=:haz01, treatment=(parity01=(case=1, control=0),), - confounders=confounders, - outcome_model = TreatmentTransformer() |> LinearBinaryClassifier(), - treatment_model = LinearBinaryClassifier() + confounders=confounders ) models = ( haz01 = with_encoder(LinearBinaryClassifier()), parity01 = LinearBinaryClassifier() ) - # TMLE ps_lowerbound = 0.025 # Cutoff hardcoded in tmle3 resampling = nothing # Vanilla TMLE weighted_fluctuation = false # Unweighted fluctuation @@ -44,21 +41,32 @@ using MLJBase adjustment_method=adjustment_method, verbosity=verbosity) tmle_result - estimate(tmle_result) ≈ -0.185533 atol = 1e-6 + @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 l, u = confint(OneSampleTTest(tmle_result)) @test l ≈ -0.279246 atol = 1e-6 @test u ≈ -0.091821 atol = 1e-6 @test OneSampleZTest(tmle_result) isa OneSampleZTest - # CV-TMLE - - tmle_result, targeted_factors = tmle!(Ψ, dataset; - resampling=resampling, + # OSE + ose_result, targeted_factors = ose!(Ψ, models, dataset; + resampling=resampling, + ps_lowerbound=ps_lowerbound, adjustment_method=adjustment_method, + verbosity=verbosity) + ose_result + + # CV-TMLE + resampling = StratifiedCV(nfolds=10) + cv_tmle_result, targeted_factors = tmle!(Ψ, models, dataset; + resampling=resampling, + weighted_fluctuation=weighted_fluctuation, ps_lowerbound=ps_lowerbound, + adjustment_method=adjustment_method, verbosity=verbosity) + cv_tmle_result + # Naive - @test naive_plugin_estimate!(Ψ, dataset) ≈ -0.150078 atol = 1e-6 + @test naive_plugin_estimate!(Ψ, models, dataset) ≈ -0.150078 atol = 1e-6 end From d13962d04a42d2d0dd98e914611403cae1bed2f8 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 6 Sep 2023 18:28:52 +0100 Subject: [PATCH 066/151] wip --- src/TMLE.jl | 2 +- src/counterfactual_mean_based/estimators.jl | 2 +- src/estimands.jl | 14 +++++ src/estimates.jl | 1 - test/estimands.jl | 38 ++++++++++++ test/estimates.jl | 65 +++++++++++++++++++++ 6 files changed, 119 insertions(+), 3 deletions(-) create mode 100644 test/estimands.jl create mode 100644 test/estimates.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index a9a745ca..3b7a821f 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -27,7 +27,7 @@ export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel export setequation!, getequation, get_conditional_distribution, set_conditional_distribution! export setmodel!, equations, reset!, parents -export ConditionalDistribution +export ConditionalDistribution, ExpectedValue export CounterfactualMean, CM export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 2a181bc4..7417228b 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -18,7 +18,7 @@ function tmle!(Ψ::CMCompositeEstimand, models, dataset; nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, resampling) initial_factors_estimate = TMLE.estimate(relevant_factors, resampling, models, initial_factors_dataset; - factors_cache=factors_cache, + factors_cache=factors_cache, verbosity=verbosity ) # Get propensity score truncation threshold diff --git a/src/estimands.jl b/src/estimands.jl index 05165270..16f5028e 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -136,6 +136,9 @@ struct CMRelevantFactors <: Estimand propensity_score::Tuple{Vararg{ConditionalDistribution}} end +CMRelevantFactors(scm, outcome_mean, propensity_score::ConditionalDistribution) = + CMRelevantFactors(scm, outcome_mean, (propensity_score,)) + string_repr(estimand::CMRelevantFactors) = string("Composite Factor: \n", "----------------\n- ", @@ -144,3 +147,14 @@ string_repr(estimand::CMRelevantFactors) = variables(estimand::CMRelevantFactors) = union(variables(estimand.outcome_mean), (variables(est) for est in estimand.propensity_score)...) + +""" + estimand_key(estimand::CMRelevantFactors) + +The key combines the outcome_mean's key anc the propensity scores' keys. +Because the order of treatment does not matter, we order them by treatment. +""" +estimand_key(estimand::CMRelevantFactors) = ( + estimand_key(estimand.outcome_mean), + sort((estimand_key(ps) for ps in estimand.propensity_score), by=x->x[1])... + ) \ No newline at end of file diff --git a/src/estimates.jl b/src/estimates.jl index aa631284..1c4e7115 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -106,7 +106,6 @@ function likelihood(estimate::ConditionalDistributionEstimate, dataset) return pdf.(ŷ, y) end - ##################################################################### ### MLCMRelevantFactors ### ##################################################################### diff --git a/test/estimands.jl b/test/estimands.jl new file mode 100644 index 00000000..73eb32ab --- /dev/null +++ b/test/estimands.jl @@ -0,0 +1,38 @@ +module TestEstimands + +using TMLE +using Test + +scm = SCM( + SE(:Y, [:T, :W]), + SE(:T₁, [:W]), + SE(:T₂, [:W]) + ) +Q̄₀ = TMLE.ExpectedValue(scm, :Y, [:T₁, :T₂, :W]) +G₀₁ = TMLE.ConditionalDistribution(scm, :T₁, [:W]) +G₀₂ = TMLE.ConditionalDistribution(scm, :T₂, [:W]) + +@testset "Test ConditionalDistribution" begin + @test TMLE.estimand_key(Q̄₀) == (:Y, Set([:W, :T₁, :T₂])) + @test TMLE.featurenames(Q̄₀) == [:T₁, :T₂, :W] + @test TMLE.variables(Q̄₀) == Set([:Y, :W, :T₁, :T₂]) + @test TMLE.variables(G₀₁) == Set([:W, :T₁]) +end + +@testset "Test CMRelevantFactors" begin + expected_key = ( + (:Y, Set([:W, :T₁, :T₂])), + (:T₁, Set([:W])), + (:T₂, Set([:W])) + ) + Q₀ = TMLE.CMRelevantFactors(scm, Q̄₀, (G₀₁, G₀₂)) + @test TMLE.estimand_key(Q₀) == expected_key + Q₀ = TMLE.CMRelevantFactors(scm, Q̄₀, (G₀₂, G₀₁)) + @test TMLE.estimand_key(Q₀) == expected_key + + @test TMLE.variables(Q₀) == Set([:Y, :W, :T₁, :T₂]) +end + +end + +true \ No newline at end of file diff --git a/test/estimates.jl b/test/estimates.jl new file mode 100644 index 00000000..1369c694 --- /dev/null +++ b/test/estimates.jl @@ -0,0 +1,65 @@ +module TestEstimates + +using Test +using TMLE +using MLJBase +using MLJGLMInterface +using DataFrames + +scm = SCM( + SE(:Y, [:T, :W]), + SE(:T₁, [:W]), + SE(:T₂, [:W]) +) + +@testset "Test MLConditionalDistribution" begin + X, y = make_regression(100, 4) + dataset = DataFrame(Y=y, T₁=X.x1, T₂=X.x2, W=X.x3) + estimand = ConditionalDistribution(scm, :Y, [:W, :T₁, :T₂]) + train_validation_indices = nothing + model = LinearRegressor() + estimate = TMLE.estimate( + estimand, dataset, model, train_validation_indices; + factors_cache=nothing, + verbosity=0 + ) + @test estimate isa TMLE.MLConditionalDistribution + expected_features = [:T₁, :T₂, :W] + @test fitted_params(estimate.machine).features == expected_features + ŷ = predict(estimate, dataset) + @test ŷ == predict(estimate.machine, dataset[!, expected_features]) + μ̂ = TMLE.expected_value(estimate, dataset) + @test μ̂ == mean.(ŷ) + @test all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) +end + +@testset "Test SampleSplitMLConditionalDistribution" begin + n = 100 + nfolds = 3 + X, y = make_regression(n, 4) + dataset = DataFrame(Y=y, T₁=X.x1, T₂=X.x2, W=X.x3) + estimand = ConditionalDistribution(scm, :Y, [:W, :T₁, :T₂]) + train_validation_indices = collect(MLJBase.train_test_pairs(CV(nfolds=nfolds), 1:n, dataset)) + model = LinearRegressor() + estimate = TMLE.estimate( + estimand, dataset, model, train_validation_indices; + factors_cache=nothing, + verbosity=0 + ) + @test estimate isa TMLE.SampleSplitMLConditionalDistribution + expected_features = [:T₁, :T₂, :W] + @test all(fitted_params(mach).features == expected_features for mach in estimate.machines) + ŷ = predict(estimate, dataset) + μ̂ = TMLE.expected_value(estimate, dataset) + for foldid in 1:nfolds + train, val = train_validation_indices[foldid] + ŷfold = predict(estimate.machines[foldid], dataset[val, expected_features]) + @test ŷ[val] == ŷfold + @test μ̂[val] == mean.(ŷfold) + end + all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) +end + +end + +true \ No newline at end of file From 0ff82c86fec31078372d123de8de24f30026c8f9 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 12 Sep 2023 17:39:08 +0100 Subject: [PATCH 067/151] add some tests --- src/estimands.jl | 2 ++ src/treatment_transformer.jl | 6 +++-- test/estimates.jl | 52 ++++++++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/src/estimands.jl b/src/estimands.jl index 16f5028e..b2f94716 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -139,6 +139,8 @@ end CMRelevantFactors(scm, outcome_mean, propensity_score::ConditionalDistribution) = CMRelevantFactors(scm, outcome_mean, (propensity_score,)) +CMRelevantFactors(scm; outcome_mean, propensity_score) = CMRelevantFactors(scm, outcome_mean, propensity_score) + string_repr(estimand::CMRelevantFactors) = string("Composite Factor: \n", "----------------\n- ", diff --git a/src/treatment_transformer.jl b/src/treatment_transformer.jl index 510a9f12..52d55cb4 100644 --- a/src/treatment_transformer.jl +++ b/src/treatment_transformer.jl @@ -22,7 +22,9 @@ function MLJBase.transform(model::TreatmentTransformer, fitresult, Xnew) Xt = MLJBase.transform(model.encoder, fitresult, Xnew) ordered_factors_names = [] ordered_factors_values = [] - for (colname, column) in pairs(Xnew) + + for colname in Tables.columnnames(Xnew) + column = Tables.getcolumn(Xnew, colname) if eltype(scitype(column)) <: OrderedFactor try push!(ordered_factors_values, float(column)) @@ -38,7 +40,7 @@ function MLJBase.transform(model::TreatmentTransformer, fitresult, Xnew) end ordered_factors_names = Tuple(ordered_factors_names) ordered_factors = NamedTuple{ordered_factors_names}(ordered_factors_values) - return merge(Xt, ordered_factors) + return merge(Tables.columntable(Xt), ordered_factors) end with_encoder(model; encoder=encoder()) = Pipeline(TreatmentTransformer(;encoder=encoder), model) \ No newline at end of file diff --git a/test/estimates.jl b/test/estimates.jl index 1369c694..69e0733b 100644 --- a/test/estimates.jl +++ b/test/estimates.jl @@ -5,6 +5,8 @@ using TMLE using MLJBase using MLJGLMInterface using DataFrames +using Distributions +using LogExpFunctions scm = SCM( SE(:Y, [:T, :W]), @@ -53,6 +55,8 @@ end μ̂ = TMLE.expected_value(estimate, dataset) for foldid in 1:nfolds train, val = train_validation_indices[foldid] + # The predictions on validation samples are made from + # the machine trained on the train sample ŷfold = predict(estimate.machines[foldid], dataset[val, expected_features]) @test ŷ[val] == ŷfold @test μ̂[val] == mean.(ŷfold) @@ -60,6 +64,54 @@ end all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) end +@testset "Test MLCMRelevantFactors" begin + n = 100 + W = rand(Normal(), n) + T₁ = rand(n) .< logistic.(1 .- W) + Y = T₁ .+ W .+ rand(n) + dataset = DataFrame(Y=Y, W=W, T₁=categorical(T₁, ordered=true)) + models = (Y=with_encoder(LinearRegressor()), T₁=LinearBinaryClassifier()) + estimand = TMLE.CMRelevantFactors( + scm, + outcome_mean = TMLE.ConditionalDistribution(scm, :Y, Set([:T₁, :W])), + propensity_score = (TMLE.ConditionalDistribution(scm, :T₁, Set([:W])),) + ) + # No resampling + resampling = nothing + estimate = Distributions.estimate(estimand, resampling, models, dataset; + factors_cache=nothing, + verbosity=0 + ) + outcome_model_features = [:T₁, :W] + @test predict(estimate.outcome_mean, dataset) == predict(estimate.outcome_mean.machine, dataset[!, outcome_model_features]) + @test predict(estimate.propensity_score[1], dataset).prob_given_ref == predict(estimate.propensity_score[1].machine, dataset[!, [:W]]).prob_given_ref + + # No resampling + nfolds = 3 + resampling = CV(nfolds=nfolds) + estimate = Distributions.estimate(estimand, resampling, models, dataset; + factors_cache=nothing, + verbosity=0 + ) + train_validation_indices = estimate.outcome_mean.train_validation_indices + # Check all models share the same train/validation indices + @test train_validation_indices == estimate.propensity_score[1].train_validation_indices + outcome_model_features = [:T₁, :W] + ŷ = predict(estimate.outcome_mean, dataset) + t̂ = predict(estimate.propensity_score[1], dataset) + for foldid in 1:nfolds + train, val = train_validation_indices[foldid] + # The predictions on validation samples are made from + # the machine trained on the train sample + ŷfold = predict(estimate.outcome_mean.machines[foldid], dataset[val, outcome_model_features]) + @test ŷ[val] == ŷfold + t̂fold = predict(estimate.propensity_score[1].machines[foldid], dataset[val, [:W]]) + @test t̂fold.prob_given_ref == t̂[val].prob_given_ref + end + +end + + end true \ No newline at end of file From 3c16d835c5a2e869b41c12503dc7eb77dbe6c1e4 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 11:04:58 +0100 Subject: [PATCH 068/151] add more work --- src/TMLE.jl | 7 +- src/estimands.jl | 45 +++---- src/estimates.jl | 78 +----------- src/estimators.jl | 85 +++++++++++++ src/scm.jl | 12 ++ src/utils.jl | 9 +- ...imators.jl => estimators_and_estimates.jl} | 47 +++++++ test/estimands.jl | 37 +++--- test/estimates.jl | 117 ------------------ test/estimators_and_estimates.jl | 91 ++++++++++++++ test/missing_management.jl | 16 +-- test/utils.jl | 26 ---- 12 files changed, 288 insertions(+), 282 deletions(-) create mode 100644 src/estimators.jl rename test/counterfactual_mean_based/{estimators.jl => estimators_and_estimates.jl} (83%) delete mode 100644 test/estimates.jl create mode 100644 test/estimators_and_estimates.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index 3b7a821f..703b43ca 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -33,19 +33,20 @@ export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS export fit!, optimize_ordering, optimize_ordering! -export tmle!, ose!, naive_plugin_estimate!, naive_plugin_estimate +export TMLEE, OSE, NAIVE export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer, with_encoder export BackdoorAdjustment - + # ############################################################################# # INCLUDES # ############################################################################# include("utils.jl") -include("scm.jl") +# include("scm.jl") include("estimands.jl") +include("estimators.jl") include("estimates.jl") include("treatment_transformer.jl") include("adjustment.jl") diff --git a/src/estimands.jl b/src/estimands.jl index b2f94716..a7ffdb68 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -11,7 +11,7 @@ string_repr(estimand::Estimand) = estimand Base.show(io::IO, ::MIME"text/plain", estimand::Estimand) = println(io, string_repr(estimand)) -treatments(Ψ::Estimand) = collect(keys(Ψ.treatment)) +treatments(Ψ::Estimand) = collect(keys(Ψ.treatment_values)) AbsentLevelError(treatment_name, key, val, levels) = ArgumentError(string( "The treatment variable ", treatment_name, "'s, '", key, "' level: '", val, @@ -58,7 +58,7 @@ Makes sure the defined treatment levels are present in the dataset. function check_treatment_levels(Ψ::Estimand, dataset) for treatment_name in treatments(Ψ) treatment_levels = levels(Tables.getcolumn(dataset, treatment_name)) - treatment_settings = getproperty(Ψ.treatment, treatment_name) + treatment_settings = getproperty(Ψ.treatment_values, treatment_name) check_treatment_settings(treatment_settings, treatment_levels, treatment_name) end end @@ -66,7 +66,7 @@ end """ Function used to sort estimands for optimal estimation ordering. """ -function estimand_key end +key(estimand::Estimand) = estimand """ optimize_ordering!(estimands::Vector{<:Estimand}) @@ -90,26 +90,21 @@ optimize_ordering(estimands::Vector{<:Estimand}) = sort(estimands, by=estimand_k Defines a Conditional Distribution estimand ``(outcome, parents) → P(outcome|parents)``. """ struct ConditionalDistribution <: Estimand - scm::SCM outcome::Symbol - parents::Set{Symbol} - function ConditionalDistribution(scm, outcome, parents) + parents::Tuple{Vararg{Symbol}} + function ConditionalDistribution(outcome, parents) outcome = Symbol(outcome) - parents = Set(Symbol(x) for x in parents) + parents = unique_sorted_tuple(parents) outcome ∉ parents || throw(SelfReferringEquationError(outcome)) # Maybe check variables are in the SCM? - return new(scm, outcome, parents) + return new(outcome, parents) end end string_repr(estimand::ConditionalDistribution) = string("P₀(", estimand.outcome, " | ", join(estimand.parents, ", "), ")") -featurenames(estimand::ConditionalDistribution) = sort(collect(estimand.parents)) - -variables(estimand::ConditionalDistribution) = union(Set([estimand.outcome]), estimand.parents) - -estimand_key(cd::ConditionalDistribution) = (cd.outcome, cd.parents) +variables(estimand::ConditionalDistribution) = (estimand.outcome, estimand.parents...) ##################################################################### ### ExpectedValue ### @@ -131,15 +126,18 @@ Defines relevant factors that need to be estimated in order to estimate any Counterfactual Mean composite estimand (see `CMCompositeEstimand`). """ struct CMRelevantFactors <: Estimand - scm::SCM outcome_mean::ConditionalDistribution propensity_score::Tuple{Vararg{ConditionalDistribution}} end -CMRelevantFactors(scm, outcome_mean, propensity_score::ConditionalDistribution) = - CMRelevantFactors(scm, outcome_mean, (propensity_score,)) +CMRelevantFactors(outcome_mean, propensity_score::ConditionalDistribution) = + CMRelevantFactors(outcome_mean, (propensity_score,)) -CMRelevantFactors(scm; outcome_mean, propensity_score) = CMRelevantFactors(scm, outcome_mean, propensity_score) +CMRelevantFactors(outcome_mean, propensity_score) = + CMRelevantFactors(outcome_mean, propensity_score) + +CMRelevantFactors(;outcome_mean, propensity_score) = + CMRelevantFactors(outcome_mean, propensity_score) string_repr(estimand::CMRelevantFactors) = string("Composite Factor: \n", @@ -148,15 +146,4 @@ string_repr(estimand::CMRelevantFactors) = join((string_repr(f) for f in estimand.propensity_score), "\n- ")) variables(estimand::CMRelevantFactors) = - union(variables(estimand.outcome_mean), (variables(est) for est in estimand.propensity_score)...) - -""" - estimand_key(estimand::CMRelevantFactors) - -The key combines the outcome_mean's key anc the propensity scores' keys. -Because the order of treatment does not matter, we order them by treatment. -""" -estimand_key(estimand::CMRelevantFactors) = ( - estimand_key(estimand.outcome_mean), - sort((estimand_key(ps) for ps in estimand.propensity_score), by=x->x[1])... - ) \ No newline at end of file + Tuple(union(variables(estimand.outcome_mean), (variables(est) for est in estimand.propensity_score)...)) diff --git a/src/estimates.jl b/src/estimates.jl index 1c4e7115..0056424a 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -25,23 +25,11 @@ string_repr(estimate::MLConditionalDistribution) = string( Base.typename(typeof(estimate.machine.model)).wrapper ) -featurenames(estimate::MLConditionalDistribution) = featurenames(estimate.estimand) - function MLJBase.predict(estimate::MLConditionalDistribution, dataset) - X = selectcols(dataset, featurenames(estimate)) + X = selectcols(dataset, estimate.estimand.parents) return predict(estimate.machine, X) end -function Distributions.estimate(factor::ConditionalDistribution, dataset, model, train_validation_indices::Nothing; - factors_cache=nothing, - verbosity=1) - relevant_dataset = TMLE.nomissing(dataset, TMLE.variables(factor)) - X = selectcols(relevant_dataset, TMLE.featurenames(factor)) - y = Tables.getcolumn(relevant_dataset, factor.outcome) - mach = machine(model, X, y) - fit!(mach, verbosity=verbosity-1) - return MLConditionalDistribution(factor, mach) -end ##################################################################### ### SampleSplitMLConditionalDistribution ### @@ -63,7 +51,7 @@ string_repr(estimate::SampleSplitMLConditionalDistribution) = ) function MLJBase.predict(estimate::SampleSplitMLConditionalDistribution, dataset) - X = selectcols(dataset, featurenames(estimate.estimand)) + X = selectcols(dataset, estimate.estimand.parents) ŷs = [] for (fold, (_, validation_indices)) in enumerate(estimate.train_validation_indices) Xval = selectrows(X, validation_indices) @@ -72,24 +60,6 @@ function MLJBase.predict(estimate::SampleSplitMLConditionalDistribution, dataset return vcat(ŷs...) end -function Distributions.estimate(factor::ConditionalDistribution, dataset, model, train_validation_indices; - factors_cache=nothing, - verbosity=1) - relevant_dataset = TMLE.selectcols(dataset, TMLE.variables(factor)) - nfolds = size(train_validation_indices, 1) - machines = Vector{Machine}(undef, nfolds) - features = TMLE.featurenames(factor) - for (index, (train_indices, _)) in enumerate(train_validation_indices) - train_dataset = selectrows(relevant_dataset, train_indices) - Xtrain = selectcols(train_dataset, features) - ytrain = Tables.getcolumn(train_dataset, factor.outcome) - mach = machine(model, Xtrain, ytrain) - fit!(mach, verbosity=verbosity-1) - machines[index] = mach - end - return SampleSplitMLConditionalDistribution(factor, train_validation_indices, machines) -end - ##################################################################### ### ConditionalDistributionEstimate ### ##################################################################### @@ -120,50 +90,6 @@ struct MLCMRelevantFactors <: Estimate propensity_score::Tuple{Vararg{ConditionalDistributionEstimate}} end -get_train_validation_indices(resampling::Nothing, factors, dataset) = nothing - -function get_train_validation_indices(resampling::ResamplingStrategy, factors, dataset) - relevant_columns = collect(TMLE.variables(factors)) - outcome_variable = factors.outcome_mean.outcome - feature_variables = filter(x -> x !== outcome_variable, relevant_columns) - return collect(MLJBase.train_test_pairs( - resampling, - 1:nrows(dataset), - selectcols(dataset, feature_variables), - Tables.getcolumn(dataset, outcome_variable) - )) -end - -function Distributions.estimate(relevant_factors::CMRelevantFactors, resampling, models, dataset; - factors_cache=nothing, - verbosity=1) - # Get train validation indices - train_validation_indices = get_train_validation_indices(resampling, relevant_factors, dataset) - # Fit propensity score - propensity_score_estimate = Tuple( - estimate( - factor, - dataset, - models[factor.outcome], - train_validation_indices; - verbosity=verbosity, - factors_cache=factors_cache) - for factor in relevant_factors.propensity_score - ) - # Fit outcome mean - outcome_mean = relevant_factors.outcome_mean - model = models[outcome_mean.outcome] - outcome_mean_estimate = TMLE.estimate( - outcome_mean, - dataset, - model, - train_validation_indices; - verbosity=verbosity, - factors_cache=factors_cache) - - return MLCMRelevantFactors(relevant_factors, outcome_mean_estimate, propensity_score_estimate) -end - string_repr(estimate::MLCMRelevantFactors) = string( "Composite Factor Estimate: \n", "-------------------------\n- ", diff --git a/src/estimators.jl b/src/estimators.jl new file mode 100644 index 00000000..6857a1e0 --- /dev/null +++ b/src/estimators.jl @@ -0,0 +1,85 @@ +abstract type Estimator end + +##################################################################### +### MLConditionalDistributionEstimator ### +##################################################################### + +struct MLConditionalDistributionEstimator <: Estimator + model +end + +function (estimator::MLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1) + # Lookup in cache + cache_key = key(estimand, estimator) + if haskey(cache, cache_key) + estimate = cache[cache_key] + verbosity > 0 && @info(string("Reusing estimate for: ", string_repr(estimand))) + return estimate + end + verbosity > 0 && @info(string("Estimating: ", string_repr(estimand))) + # Otherwise estimate + relevant_dataset = TMLE.nomissing(dataset, TMLE.variables(estimand)) + # Fit Conditional DIstribution using MLJ + X = selectcols(relevant_dataset, estimand.parents) + y = Tables.getcolumn(relevant_dataset, estimand.outcome) + mach = machine(estimator.model, X, y) + fit!(mach, verbosity=verbosity-1) + # Build estimate + estimate = MLConditionalDistribution(estimand, mach) + # Update cache + cache[cache_key] = estimate + + return estimate +end + +key(estimator::MLConditionalDistributionEstimator) = + (MLConditionalDistributionEstimator, estimator.model) + +##################################################################### +### SampleSplitMLConditionalDistributionEstimator ### +##################################################################### + +struct SampleSplitMLConditionalDistributionEstimator <: Estimator + model + train_validation_indices +end + +function (estimator::SampleSplitMLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1) + # Lookup in cache + cache_key = key(estimand, estimator) + if haskey(cache, cache_key) + estimate = cache[cache_key] + verbosity > 0 && @info(string("Reusing estimate for: ", string_repr(estimand))) + return estimate + end + # Otherwise estimate + verbosity > 0 && @info(string("Estimating: ", string_repr(estimand))) + + relevant_dataset = TMLE.selectcols(dataset, TMLE.variables(estimand)) + nfolds = size(estimator.train_validation_indices, 1) + machines = Vector{Machine}(undef, nfolds) + # Fit Conditional DIstribution on each training split using MLJ + for (index, (train_indices, _)) in enumerate(estimator.train_validation_indices) + train_dataset = selectrows(relevant_dataset, train_indices) + Xtrain = selectcols(train_dataset, estimand.parents) + ytrain = Tables.getcolumn(train_dataset, estimand.outcome) + mach = machine(estimator.model, Xtrain, ytrain) + fit!(mach, verbosity=verbosity-1) + machines[index] = mach + end + # Build estimate + estimate = SampleSplitMLConditionalDistribution(estimand, estimator.train_validation_indices, machines) + # Update cache + cache[cache_key] = estimate + + return estimate +end + +key(estimator::SampleSplitMLConditionalDistributionEstimator) = + (MLConditionalDistributionEstimator, estimator.model, estimator.train_validation_indices) + +ConditionalDistributionEstimator(train_validation_indices::Nothing, model) = + MLConditionalDistributionEstimator(model) + +ConditionalDistributionEstimator(train_validation_indices, model) = + SampleSplitMLConditionalDistributionEstimator(model, train_validation_indices) diff --git a/src/scm.jl b/src/scm.jl index c295ae65..cfd0bc4e 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -174,6 +174,18 @@ function is_upstream(var₁, var₂, scm::SCM) end end +VariableNotAChildInSCMError(variable) = ArgumentError(string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")) +TreatmentMustBeInOutcomeParentsError(variable) = ArgumentError(string("Treatment variable ", variable, " must be a parent of the outcome.")) + +function check_parameter_against_scm(scm::SCM, outcome, treatment) + eqs = equations(scm) + haskey(eqs, outcome) || throw(VariableNotAChildInSCMError(outcome)) + for treatment_variable in keys(treatment) + haskey(eqs, treatment_variable) || throw(VariableNotAChildInSCMError(treatment_variable)) + is_upstream(treatment_variable, outcome, scm) || throw(TreatmentMustBeInOutcomeParentsError(treatment_variable)) + end +end + ##################################################################### ### StaticConfoundedModel ### ##################################################################### diff --git a/src/utils.jl b/src/utils.jl index ee95122f..d7b05a56 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,13 @@ ## General Utilities ############################################################################### +key(estimand, estimator) = (key(estimand), key(estimator)) + +unique_sorted_tuple(iter) = Tuple(sort(unique(Symbol(x) for x in iter))) + +choose_initial_dataset(dataset, nomissing_dataset, resampling::Nothing) = dataset +choose_initial_dataset(dataset, nomissing_dataset, resampling) = nomissing_dataset + selectcols(data, cols) = data |> TableOperations.select(cols...) |> Tables.columntable function logit!(v) @@ -36,14 +43,12 @@ function indicator_values(indicators, T) return indic end - expected_value(ŷ::UnivariateFiniteVector{Multiclass{2}}) = pdf.(ŷ, levels(first(ŷ))[2]) expected_value(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = mean.(ŷ) expected_value(ŷ::AbstractVector{<:Real}) = ŷ training_expected_value(Q::Machine, dataset) = expected_value(predict(Q, dataset)) - function counterfactualTreatment(vals, T) Tnames = Tables.columnnames(T) n = nrows(T) diff --git a/test/counterfactual_mean_based/estimators.jl b/test/counterfactual_mean_based/estimators_and_estimates.jl similarity index 83% rename from test/counterfactual_mean_based/estimators.jl rename to test/counterfactual_mean_based/estimators_and_estimates.jl index 6ce83b3f..c93387df 100644 --- a/test/counterfactual_mean_based/estimators.jl +++ b/test/counterfactual_mean_based/estimators_and_estimates.jl @@ -57,6 +57,53 @@ end table_types = (Tables.columntable, DataFrame) +@testset "Test MLCMRelevantFactors" begin + n = 100 + W = rand(Normal(), n) + T₁ = rand(n) .< logistic.(1 .- W) + Y = T₁ .+ W .+ rand(n) + dataset = DataFrame(Y=Y, W=W, T₁=categorical(T₁, ordered=true)) + models = (Y=with_encoder(LinearRegressor()), T₁=LinearBinaryClassifier()) + estimand = TMLE.CMRelevantFactors( + scm, + outcome_mean = TMLE.ConditionalDistribution(scm, :Y, Set([:T₁, :W])), + propensity_score = (TMLE.ConditionalDistribution(scm, :T₁, Set([:W])),) + ) + # No resampling + resampling = nothing + estimate = Distributions.estimate(estimand, resampling, models, dataset; + factors_cache=nothing, + verbosity=0 + ) + outcome_model_features = [:T₁, :W] + @test predict(estimate.outcome_mean, dataset) == predict(estimate.outcome_mean.machine, dataset[!, outcome_model_features]) + @test predict(estimate.propensity_score[1], dataset).prob_given_ref == predict(estimate.propensity_score[1].machine, dataset[!, [:W]]).prob_given_ref + + # No resampling + nfolds = 3 + resampling = CV(nfolds=nfolds) + estimate = Distributions.estimate(estimand, resampling, models, dataset; + factors_cache=nothing, + verbosity=0 + ) + train_validation_indices = estimate.outcome_mean.train_validation_indices + # Check all models share the same train/validation indices + @test train_validation_indices == estimate.propensity_score[1].train_validation_indices + outcome_model_features = [:T₁, :W] + ŷ = predict(estimate.outcome_mean, dataset) + t̂ = predict(estimate.propensity_score[1], dataset) + for foldid in 1:nfolds + train, val = train_validation_indices[foldid] + # The predictions on validation samples are made from + # the machine trained on the train sample + ŷfold = predict(estimate.outcome_mean.machines[foldid], dataset[val, outcome_model_features]) + @test ŷ[val] == ŷfold + t̂fold = predict(estimate.propensity_score[1].machines[foldid], dataset[val, [:W]]) + @test t̂fold.prob_given_ref == t̂[val].prob_given_ref + end + +end + @testset "Test Warm restart: ATE single treatment, $tt" for tt in table_types dataset, scm = build_dataset_and_scm(;n=50_000) dataset = tt(dataset) diff --git a/test/estimands.jl b/test/estimands.jl index 73eb32ab..ee751731 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -3,34 +3,29 @@ module TestEstimands using TMLE using Test -scm = SCM( - SE(:Y, [:T, :W]), - SE(:T₁, [:W]), - SE(:T₂, [:W]) - ) -Q̄₀ = TMLE.ExpectedValue(scm, :Y, [:T₁, :T₂, :W]) -G₀₁ = TMLE.ConditionalDistribution(scm, :T₁, [:W]) -G₀₂ = TMLE.ConditionalDistribution(scm, :T₂, [:W]) @testset "Test ConditionalDistribution" begin - @test TMLE.estimand_key(Q̄₀) == (:Y, Set([:W, :T₁, :T₂])) - @test TMLE.featurenames(Q̄₀) == [:T₁, :T₂, :W] - @test TMLE.variables(Q̄₀) == Set([:Y, :W, :T₁, :T₂]) - @test TMLE.variables(G₀₁) == Set([:W, :T₁]) + distr = ConditionalDistribution("Y", ["C", 1, :A, ]) + @test distr.outcome === :Y + @test distr.parents === (Symbol("1"), :A, :C) + @test TMLE.variables(distr) == (:Y, Symbol("1"), :A, :C) end @testset "Test CMRelevantFactors" begin - expected_key = ( - (:Y, Set([:W, :T₁, :T₂])), - (:T₁, Set([:W])), - (:T₂, Set([:W])) + η = TMLE.CMRelevantFactors( + outcome_mean=ExpectedValue(:Y, [:T, :W]), + propensity_score=ConditionalDistribution(:T, [:W]) ) - Q₀ = TMLE.CMRelevantFactors(scm, Q̄₀, (G₀₁, G₀₂)) - @test TMLE.estimand_key(Q₀) == expected_key - Q₀ = TMLE.CMRelevantFactors(scm, Q̄₀, (G₀₂, G₀₁)) - @test TMLE.estimand_key(Q₀) == expected_key + @test TMLE.variables(η) == (:Y, :T, :W) - @test TMLE.variables(Q₀) == Set([:Y, :W, :T₁, :T₂]) + η = TMLE.CMRelevantFactors( + outcome_mean=ExpectedValue(:Y, [:T, :W]), + propensity_score=( + ConditionalDistribution(:T₁, [:W₁]), + ConditionalDistribution(:T₂, [:W₂₁, :W₂₂]) + ) + ) + @test TMLE.variables(η) == (:Y, :T, :W, :T₁, :W₁, :T₂, :W₂₁, :W₂₂) end end diff --git a/test/estimates.jl b/test/estimates.jl deleted file mode 100644 index 69e0733b..00000000 --- a/test/estimates.jl +++ /dev/null @@ -1,117 +0,0 @@ -module TestEstimates - -using Test -using TMLE -using MLJBase -using MLJGLMInterface -using DataFrames -using Distributions -using LogExpFunctions - -scm = SCM( - SE(:Y, [:T, :W]), - SE(:T₁, [:W]), - SE(:T₂, [:W]) -) - -@testset "Test MLConditionalDistribution" begin - X, y = make_regression(100, 4) - dataset = DataFrame(Y=y, T₁=X.x1, T₂=X.x2, W=X.x3) - estimand = ConditionalDistribution(scm, :Y, [:W, :T₁, :T₂]) - train_validation_indices = nothing - model = LinearRegressor() - estimate = TMLE.estimate( - estimand, dataset, model, train_validation_indices; - factors_cache=nothing, - verbosity=0 - ) - @test estimate isa TMLE.MLConditionalDistribution - expected_features = [:T₁, :T₂, :W] - @test fitted_params(estimate.machine).features == expected_features - ŷ = predict(estimate, dataset) - @test ŷ == predict(estimate.machine, dataset[!, expected_features]) - μ̂ = TMLE.expected_value(estimate, dataset) - @test μ̂ == mean.(ŷ) - @test all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) -end - -@testset "Test SampleSplitMLConditionalDistribution" begin - n = 100 - nfolds = 3 - X, y = make_regression(n, 4) - dataset = DataFrame(Y=y, T₁=X.x1, T₂=X.x2, W=X.x3) - estimand = ConditionalDistribution(scm, :Y, [:W, :T₁, :T₂]) - train_validation_indices = collect(MLJBase.train_test_pairs(CV(nfolds=nfolds), 1:n, dataset)) - model = LinearRegressor() - estimate = TMLE.estimate( - estimand, dataset, model, train_validation_indices; - factors_cache=nothing, - verbosity=0 - ) - @test estimate isa TMLE.SampleSplitMLConditionalDistribution - expected_features = [:T₁, :T₂, :W] - @test all(fitted_params(mach).features == expected_features for mach in estimate.machines) - ŷ = predict(estimate, dataset) - μ̂ = TMLE.expected_value(estimate, dataset) - for foldid in 1:nfolds - train, val = train_validation_indices[foldid] - # The predictions on validation samples are made from - # the machine trained on the train sample - ŷfold = predict(estimate.machines[foldid], dataset[val, expected_features]) - @test ŷ[val] == ŷfold - @test μ̂[val] == mean.(ŷfold) - end - all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) -end - -@testset "Test MLCMRelevantFactors" begin - n = 100 - W = rand(Normal(), n) - T₁ = rand(n) .< logistic.(1 .- W) - Y = T₁ .+ W .+ rand(n) - dataset = DataFrame(Y=Y, W=W, T₁=categorical(T₁, ordered=true)) - models = (Y=with_encoder(LinearRegressor()), T₁=LinearBinaryClassifier()) - estimand = TMLE.CMRelevantFactors( - scm, - outcome_mean = TMLE.ConditionalDistribution(scm, :Y, Set([:T₁, :W])), - propensity_score = (TMLE.ConditionalDistribution(scm, :T₁, Set([:W])),) - ) - # No resampling - resampling = nothing - estimate = Distributions.estimate(estimand, resampling, models, dataset; - factors_cache=nothing, - verbosity=0 - ) - outcome_model_features = [:T₁, :W] - @test predict(estimate.outcome_mean, dataset) == predict(estimate.outcome_mean.machine, dataset[!, outcome_model_features]) - @test predict(estimate.propensity_score[1], dataset).prob_given_ref == predict(estimate.propensity_score[1].machine, dataset[!, [:W]]).prob_given_ref - - # No resampling - nfolds = 3 - resampling = CV(nfolds=nfolds) - estimate = Distributions.estimate(estimand, resampling, models, dataset; - factors_cache=nothing, - verbosity=0 - ) - train_validation_indices = estimate.outcome_mean.train_validation_indices - # Check all models share the same train/validation indices - @test train_validation_indices == estimate.propensity_score[1].train_validation_indices - outcome_model_features = [:T₁, :W] - ŷ = predict(estimate.outcome_mean, dataset) - t̂ = predict(estimate.propensity_score[1], dataset) - for foldid in 1:nfolds - train, val = train_validation_indices[foldid] - # The predictions on validation samples are made from - # the machine trained on the train sample - ŷfold = predict(estimate.outcome_mean.machines[foldid], dataset[val, outcome_model_features]) - @test ŷ[val] == ŷfold - t̂fold = predict(estimate.propensity_score[1].machines[foldid], dataset[val, [:W]]) - @test t̂fold.prob_given_ref == t̂[val].prob_given_ref - end - -end - - -end - -true \ No newline at end of file diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl new file mode 100644 index 00000000..7ad010ab --- /dev/null +++ b/test/estimators_and_estimates.jl @@ -0,0 +1,91 @@ +module TestEstimatorsAndEstimates + +using Test +using TMLE +using MLJBase +using DataFrames +using MLJGLMInterface + +verbosity = 1 +n = 100 +X, y = make_regression(n, 4) +dataset = DataFrame(Y=y, T₁=X.x1, T₂=X.x2, W=X.x3) +estimand = ConditionalDistribution(:Y, [:W, :T₁, :T₂]) +fit_log = string("Estimating: ", TMLE.string_repr(estimand)) +reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) + +@testset "Test MLConditionalDistributionEstimator" begin + estimator = TMLE.MLConditionalDistributionEstimator(LinearRegressor()) + # Fitting with no cache + cache = Dict() + estimate = @test_logs (:info, fit_log) estimator(estimand, dataset; cache=cache, verbosity=verbosity) + expected_features = collect(estimand.parents) + @test estimate isa TMLE.MLConditionalDistribution + @test fitted_params(estimate.machine).features == expected_features + ŷ = predict(estimate, dataset) + @test ŷ == predict(estimate.machine, dataset[!, expected_features]) + μ̂ = TMLE.expected_value(estimate, dataset) + @test μ̂ == mean.(ŷ) + @test all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) + # Uses the cache instead of fitting + new_estimator = TMLE.MLConditionalDistributionEstimator(LinearRegressor()) + @test TMLE.key(new_estimator) == TMLE.key(estimator) + @test_logs (:info, reuse_log) estimator(estimand, dataset; cache=cache, verbosity=verbosity) + # Changing the model leads to refit + new_estimator = TMLE.MLConditionalDistributionEstimator(LinearRegressor(fit_intercept=false)) + @test_logs (:info, fit_log) new_estimator(estimand, dataset; cache=cache, verbosity=verbosity) + # The cache contains two estimators for the estimand + @test length(cache) == 2 +end + +@testset "Test SampleSplitMLConditionalDistributionEstimator" begin + nfolds = 3 + train_validation_indices = collect(MLJBase.train_test_pairs(CV(nfolds=nfolds), 1:n, dataset)) + model = LinearRegressor() + estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( + LinearRegressor(), + train_validation_indices + ) + cache = Dict() + estimate = @test_logs (:info, fit_log) estimator(estimand, dataset;cache=cache, verbosity=verbosity) + @test estimate isa TMLE.SampleSplitMLConditionalDistribution + expected_features = collect(estimand.parents) + @test all(fitted_params(mach).features == expected_features for mach in estimate.machines) + ŷ = predict(estimate, dataset) + μ̂ = TMLE.expected_value(estimate, dataset) + for foldid in 1:nfolds + train, val = train_validation_indices[foldid] + # The predictions on validation samples are made from + # the machine trained on the train sample + ŷfold = predict(estimate.machines[foldid], dataset[val, expected_features]) + @test ŷ[val] == ŷfold + @test μ̂[val] == mean.(ŷfold) + end + all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) + # Uses the cache instead of fitting + new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( + LinearRegressor(), + train_validation_indices + ) + @test TMLE.key(new_estimator) == TMLE.key(estimator) + @test_logs (:info, reuse_log) estimator(estimand, dataset;cache=cache, verbosity=verbosity) + # Changing the model leads to refit + new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( + LinearRegressor(fit_intercept=false), + train_validation_indices + ) + @test_logs (:info, fit_log) new_estimator(estimand, dataset; cache=cache, verbosity=verbosity) + # Changing the train/validation splits leads to refit + train_validation_indices = collect(MLJBase.train_test_pairs(CV(nfolds=4), 1:n, dataset)) + new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( + LinearRegressor(), + train_validation_indices + ) + @test_logs (:info, fit_log) new_estimator(estimand, dataset; cache=cache, verbosity=verbosity) + # The cache contains 3 estimators for the estimand + @test length(cache) == 3 +end + +end + +true \ No newline at end of file diff --git a/test/missing_management.jl b/test/missing_management.jl index 7211973a..6125deb8 100644 --- a/test/missing_management.jl +++ b/test/missing_management.jl @@ -20,13 +20,11 @@ function dataset_with_missing_and_ordered_treatment(;n=1000) dataset.W[1:5] .= missing dataset.T[6:10] .= missing dataset.Y[11:15] .= missing - scm = StaticConfoundedModel(:Y, :T, :W, treatment_model = LogisticClassifier(lambda=0)) - return dataset, scm + return dataset end - @testset "Test nomissing" begin - dataset, scm = dataset_with_missing_and_ordered_treatment(;n=100) + dataset = dataset_with_missing_and_ordered_treatment(;n=100) # filter missing rows based on W column filtered = TMLE.nomissing(dataset, [:W]) @test filtered.W == dataset.W[6:end] @@ -42,12 +40,14 @@ end end @testset "Test estimation with missing values and ordered factor treatment" begin - dataset, scm = dataset_with_missing_and_ordered_treatment(;n=1000) + dataset = dataset_with_missing_and_ordered_treatment(;n=1000) Ψ = ATE( - scm, outcome=:Y, - treatment=(T=(case=1, control=0),)) - tmle_result, fluctuation_mach = tmle!(Ψ, dataset; verbosity=0) + treatment_values=(T=(case=1, control=0),), + treatment_confounders=(T=[:W],)) + models=(Y=with_encoder(LinearRegressor()), T=LogisticClassifier(lambda=0)) + tmle = TMLEE(models) + tmle_result, fluctuation_mach = tmle(Ψ, dataset; verbosity=0) test_coverage(tmle_result, 1) test_fluct_decreases_risk(Ψ, fluctuation_mach) end diff --git a/test/utils.jl b/test/utils.jl index f1afcf66..2ac87525 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -8,32 +8,6 @@ using CategoricalArrays using MLJLinearModels using MLJModels -@testset "Test estimand validation" begin - dataset = ( - W₁ = [1., 2., 3.], - T = categorical([1, 2, 3]), - Y = [1., 2., 3.] - ) - # Treatment values absent from dataset - scm = SCM( - SE(:T, [:W₁]), - SE(:Y, [:T, :W₁]) - ) - Ψ = ATE( - scm, - outcome=:Y, - treatment=(T=(case=1, control=0),) - ) - @test_throws TMLE.AbsentLevelError("T", "control", 0, [1, 2, 3]) TMLE.check_treatment_levels(Ψ, dataset) - - Ψ = CM( - scm, - outcome=:Y, - treatment=(T=0,), - ) - @test_throws TMLE.AbsentLevelError("T", 0, [1, 2, 3]) TMLE.check_treatment_levels(Ψ, dataset) -end - @testset "Test expected_value" begin n = 100 X = MLJBase.table(rand(n, 3)) From efdb459b5ba61c294d59cc5d3bb1375f375a8ecd Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 11:25:16 +0100 Subject: [PATCH 069/151] fix estimands --- src/counterfactual_mean_based/estimands.jl | 255 +++++------------ test/counterfactual_mean_based/estimands.jl | 287 ++------------------ 2 files changed, 88 insertions(+), 454 deletions(-) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index b66afb62..d0819a9d 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -1,201 +1,92 @@ -##################################################################### -### Counterfactual Mean ### -##################################################################### - -""" -# Counterfactual Mean / CM - -## Definition - -``CM(Y, T=t) = E[Y|do(T=t)]`` - -## Constructors - -- CM(scm::SCM; outcome::Symbol, treatment::NamedTuple) -- CM( - outcome::Symbol, - treatment::NamedTuple, - confounders::Union{Symbol, AbstractVector{Symbol}}, - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LinearBinaryClassifier() +ESTIMANDS_DOCS = Dict( + :CM => (formula="``CM(Y, T=t) = E[Y|do(T=t)]``",), + :ATE => (formula="``ATE(Y, T, case, control) = E[Y|do(T=case)] - E[Y|do(T=control)``",), + :IATE => (formula="``IATE = E[Y|do(T₁=1, T₂=1)] - E[Y|do(T₁=1, T₂=0)] - E[Y|do(T₁=0, T₂=1)] + E[Y|do(T₁=0, T₂=0)]``",) ) -## Example - -```julia -Ψ = CM(scm, outcome=:Y, treatment=(T=1,)) - -``` -""" -struct CounterfactualMean <: Estimand - scm::StructuralCausalModel - outcome::Symbol - treatment::NamedTuple - function CounterfactualMean(scm, outcome, treatment) - check_parameter_against_scm(scm, outcome, treatment) - return new(scm, outcome, treatment) - end -end - -const CM = CounterfactualMean - -indicator_fns(Ψ::CM) = Dict(values(Ψ.treatment) => 1.) - - -##################################################################### -### Average Treatment Effect ### -##################################################################### -""" -# Average Treatment Effect / ATE - -## Definition - -``ATE(Y, T, case, control) = E[Y|do(T=case)] - E[Y|do(T=control)`` +# Define constructors/name for CMCompositeEstimand types -## Constructors +for (typename, (formula,)) ∈ ESTIMANDS_DOCS + ex = quote + # """ + # # $(typename) + + # ## Definition + + # For two treatments with case/control settings (1, 0): + + # $(formula) + + # ## Constructors + + # - $(typename)(outcome, treatment, confounders; extra_outcome_covariates=Set{Symbol}()) + # - $(typename)(; + # outcome, + # treatment, + # confounders, + # extra_outcome_covariates + # ) + + # ## Example + + # ```julia + # Ψ = $(typename)(outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1,control=0), confounders=[:W₁, :W₂]) + # ``` + # """ + struct $(typename) <: Estimand + outcome::Symbol + treatment_values::NamedTuple + treatment_confounders::NamedTuple + outcome_extra_covariates::Tuple{Vararg{Symbol}} + + function $(typename)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) + outcome = Symbol(outcome) + treatment_variables = Tuple(keys(treatment_values)) + treatment_confounders = NamedTuple{treatment_variables}([unique_sorted_tuple(treatment_confounders[T]) for T ∈ treatment_variables]) + outcome_extra_covariates = unique_sorted_tuple(outcome_extra_covariates) + return new(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) + end + end -- ATE(scm::SCM; outcome::Symbol, treatment::NamedTuple) -- ATE( - outcome::Symbol, - treatment::NamedTuple, - confounders::Union{Symbol, AbstractVector{Symbol}}, - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LinearBinaryClassifier() -) + $(typename)(outcome, treatment_values, treatment_confounders; outcome_extra_covariates=()) = + $(typename)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) + + $(typename)(;outcome, treatment_values, treatment_confounders, outcome_extra_covariates=()) = + $(typename)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) -## Example - -```julia -Ψ = ATE(scm, outcome=:Y, treatment=(T=(case=1,control=0),) -``` -""" -struct AverageTreatmentEffect <: Estimand - scm::StructuralCausalModel - outcome::Symbol - treatment::NamedTuple - function AverageTreatmentEffect(scm, outcome, treatment) - check_parameter_against_scm(scm, outcome, treatment) - return new(scm, outcome, treatment) + name(::Type{$(typename)}) = string($(typename)) end + eval(ex) end -const ATE = AverageTreatmentEffect +CMCompositeEstimand = Union{(eval(x) for x in keys(ESTIMANDS_DOCS))...} + +indicator_fns(Ψ::CM) = Dict(values(Ψ.treatment_values) => 1.) function indicator_fns(Ψ::ATE) case = [] control = [] - for treatment in Ψ.treatment + for treatment in Ψ.treatment_values push!(case, treatment.case) push!(control, treatment.control) end return Dict(Tuple(case) => 1., Tuple(control) => -1.) end -##################################################################### -### Interaction Average Treatment Effect ### -##################################################################### - -""" -# Interaction Average Treatment Effect / IATE - -## Definition - -For two treatments with case/control settings (1, 0): - -``IATE = E[Y|do(T₁=1, T₂=1)] - E[Y|do(T₁=1, T₂=0)] - E[Y|do(T₁=0, T₂=1)] + E[Y|do(T₁=0, T₂=0)]`` - -## Constructors - -- IATE(scm::SCM; outcome::Symbol, treatment::NamedTuple) -- IATE( - outcome::Symbol, - treatment::NamedTuple, - confounders::Union{Symbol, AbstractVector{Symbol}}, - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LinearBinaryClassifier() -) - -## Example - -```julia -Ψ = IATE(scm, outcome=:Y, treatment=(T₁=(case=1,control=0), T₂=(case=1,control=0)) -``` -""" -struct InteractionAverageTreatmentEffect <: Estimand - scm::StructuralCausalModel - outcome::Symbol - treatment::NamedTuple - function InteractionAverageTreatmentEffect(scm, outcome, treatment) - check_parameter_against_scm(scm, outcome, treatment) - return new(scm, outcome, treatment) - end -end - -const IATE = InteractionAverageTreatmentEffect - -ncases(value, Ψ::IATE) = sum(value[i] == Ψ.treatment[i].case for i in eachindex(value)) +ncases(value, Ψ::IATE) = sum(value[i] == Ψ.treatment_values[i].case for i in eachindex(value)) function indicator_fns(Ψ::IATE) N = length(treatments(Ψ)) key_vals = Pair[] - for cf in Iterators.product((values(Ψ.treatment[T]) for T in treatments(Ψ))...) + for cf in Iterators.product((values(Ψ.treatment_values[T]) for T in treatments(Ψ))...) push!(key_vals, cf => float((-1)^(N - ncases(cf, Ψ)))) end return Dict(key_vals...) end -##################################################################### -### Methods ### -##################################################################### - -AVAILABLE_ESTIMANDS = (CM, ATE, IATE) -CMCompositeTypenames = [:CM, :ATE, :IATE] -CMCompositeEstimand = Union{(eval(x) for x in CMCompositeTypenames)...} - -# Define constructors/name for CMCompositeEstimand types -for typename in CMCompositeTypenames - ex = quote - name(::Type{$(typename)}) = string($(typename)) - - $(typename)(scm::SCM; outcome::Symbol, treatment::NamedTuple) = $(typename)(scm, outcome, treatment) - - function $(typename)(; - outcome::Symbol, - treatment::NamedTuple, - confounders::Union{Symbol, AbstractVector{Symbol}}, - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing - ) - scm = StaticConfoundedModel( - [outcome], - collect(keys(treatment)), - confounders; - covariates=covariates - ) - return $(typename)(scm; outcome=outcome, treatment=treatment) - end - end - - eval(ex) -end - -VariableNotAChildInSCMError(variable) = ArgumentError(string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")) -TreatmentMustBeInOutcomeParentsError(variable) = ArgumentError(string("Treatment variable ", variable, " must be a parent of the outcome.")) - -function check_parameter_against_scm(scm::SCM, outcome, treatment) - eqs = equations(scm) - haskey(eqs, outcome) || throw(VariableNotAChildInSCMError(outcome)) - for treatment_variable in keys(treatment) - haskey(eqs, treatment_variable) || throw(VariableNotAChildInSCMError(treatment_variable)) - is_upstream(treatment_variable, outcome, scm) || throw(TreatmentMustBeInOutcomeParentsError(treatment_variable)) - end -end - -function get_relevant_factors(Ψ::CMCompositeEstimand; adjustment_method=BackdoorAdjustment()) - outcome_model = ExpectedValue(Ψ.scm, outcome(Ψ), outcome_parents(adjustment_method, Ψ)) - treatment_factors = Tuple(ConditionalDistribution(Ψ.scm, treatment, parents(Ψ.scm, treatment)) for treatment in treatments(Ψ)) - return CMRelevantFactors(Ψ.scm, outcome_model, treatment_factors) +function get_relevant_factors(Ψ::CMCompositeEstimand) + outcome_model = ExpectedValue(Ψ.outcome, Tuple(union(Ψ.outcome_extra_covariates, keys(Ψ.treatment_confounders), (Ψ.treatment_confounders)...))) + treatment_factors = Tuple(ConditionalDistribution(T, Ψ.treatment_confounders[T]) for T in treatments(Ψ)) + return CMRelevantFactors(outcome_model, treatment_factors) end function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand @@ -203,23 +94,7 @@ function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEsti name(T), "\n-----", "\nOutcome: ", Ψ.outcome, - "\nTreatment: ", Ψ.treatment + "\nTreatment: ", Ψ.treatment_values ) println(io, param_string) -end - -outcome_equation(Ψ::CMCompositeEstimand) = Ψ.scm[outcome(Ψ)] - -get_outcome_model(Ψ::CMCompositeEstimand) = outcome_equation(Ψ).mach - -get_outcome_datas(Ψ::CMCompositeEstimand) = get_outcome_model(Ψ).data - -outcome(Ψ::CMCompositeEstimand) = Ψ.outcome -outcome(dataset, Ψ::CMCompositeEstimand) = Tables.getcolumn(dataset, outcome(Ψ)) - -function estimand_key(Ψ::CMCompositeEstimand) - return ( - join(treatments(Ψ), "_"), - string(outcome(Ψ)), - ) -end +end \ No newline at end of file diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index 0e6b103c..6305247f 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -8,257 +8,8 @@ using StableRNGs using MLJBase using MLJGLMInterface -@testset "Test various methods" begin - n = 5 - dataset = ( - Y = rand(n), - T = categorical([0, 1, 1, 0, 0]), - W₁ = rand(n), - W₂ = rand(n) - ) - # CM - scm = StaticConfoundedModel( - :Y, :T, [:W₁, :W₂], - covariates=:C₁ - ) - # Parameter validation - @test_throws TMLE.VariableNotAChildInSCMError("UndefinedOutcome") CM( - scm, - treatment=(T=1,), - outcome=:UndefinedOutcome - ) - @test_throws TMLE.VariableNotAChildInSCMError("UndefinedTreatment") CM( - scm, - treatment=(UndefinedTreatment=1,), - outcome=:Y - ) - @test_throws TMLE.TreatmentMustBeInOutcomeParentsError("Y") CM( - scm, - treatment=(Y=1,), - outcome=:T - ) - - Ψ = CM( - scm, - treatment=(T=1,), - outcome =:Y - ) - - @test TMLE.treatments(Ψ) == [:T] - @test TMLE.treatments(dataset, Ψ) == (T=dataset.T,) - @test TMLE.outcome(Ψ) == :Y - - # ATE - scm = SCM( - SE(:Z, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁]), - SE(:T₁, [:W₁₁, :W₁₂]), - SE(:T₂, [:W₂₁]), - ) - Ψ = ATE( - scm, - treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), - outcome =:Z, - ) - @test TMLE.treatments(Ψ) == [:T₁, :T₂] - @test TMLE.outcome(Ψ) == :Z - - # IATE - Ψ = IATE( - scm, - treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), - outcome =:Z, - ) - @test TMLE.treatments(Ψ) == [:T₁, :T₂] - @test TMLE.outcome(Ψ) == :Z -end - -@testset "Test constructors" begin - Ψ = CM(outcome=:Y, treatment=(T=1,), confounders=[:W]) - @test parents(Ψ.scm, :Y) == Set([:T, :W]) - - Ψ = ATE(outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), confounders=:W) - @test parents(Ψ.scm, :Y) == Set([:T₁, :T₂, :W]) - - Ψ = IATE(outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), confounders=:W) - @test parents(Ψ.scm, :Y) == Set([:T₁, :T₂, :W]) -end - - -@testset "Test fit!" begin - rng = StableRNG(123) - n = 100 - dataset = ( - Ycat = categorical(rand(rng, [0, 1], n)), - Ycont = rand(rng, n), - T₁ = categorical(rand(rng, [0, 1], n)), - T₂ = categorical(rand(rng, [0, 1], n)), - W₁₁ = rand(rng, n), - W₁₂ = rand(rng, n), - W₂₁ = rand(rng, n), - W₂₂ = rand(rng, n), - C = rand(rng, n) - ) - - scm = SCM( - SE(:Ycat, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C]), - SE(:Ycont, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C]), - SE(:T₁, [:W₁₁, :W₁₂]), - SE(:T₂, [:W₂₁, :W₂₂]), - ) - set_conditional_distribution!(scm, :Ycat, with_encoder(LinearBinaryClassifier())) - set_conditional_distribution!(scm, :Ycont, with_encoder(LinearRegressor())) - set_conditional_distribution!(scm, :T₁, LinearBinaryClassifier()) - set_conditional_distribution!(scm, :T₂, LinearBinaryClassifier()) - # Fits CM required factors - Ψ = CM( - scm, - treatment=(T₁=1,), - outcome=:Ycat - ) - log_sequence = ( - (:info, TMLE.UsingNaturalDistributionLog(scm, :Ycat, Set([:W₁₂, :W₁₁, :T₁, :C]))), - (:info, "Fitting Conditional Distribution Factor: Ycat | W₁₂, W₁₁, T₁, C"), - (:info, "Fitting Conditional Distribution Factor: T₁ | W₁₂, W₁₁"), - ) - @test_logs log_sequence... fit!(Ψ, dataset; adjustment_method=BackdoorAdjustment([:C]), verbosity=1) - outcome_dist = get_conditional_distribution(scm, :Ycat, Set([:W₁₂, :W₁₁, :T₁, :C])) - @test keys(fitted_params(outcome_dist.machine)) == (:linear_binary_classifier, :treatment_transformer) - @test keys(outcome_dist.machine.data[1]) == (:W₁₂, :W₁₁, :T₁, :C) - treatment_dist = get_conditional_distribution(scm, :T₁, Set([:W₁₁, :W₁₂])) - @test fitted_params(treatment_dist.machine).features == [:W₁₂, :W₁₁] - @test keys(treatment_dist.machine.data[1]) == (:W₁₂, :W₁₁) - # The natural outcome distribution has not been fitted for instance - natural_outcome_dist = get_conditional_distribution(scm, :Ycat, Set([:T₂, :W₁₂, :W₂₂, :W₂₁, :W₁₁, :T₁, :C])) - @test !isdefined(natural_outcome_dist, :machine) - - # Fits ATE required equations that haven't been fitted yet. - Ψ = ATE( - scm, - treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), - outcome =:Ycat, - ) - conditioning_set = Set([:W₁₂, :W₂₂, :W₂₁, :W₁₁, :T₁, :T₂]) - log_sequence = ( - (:info, TMLE.UsingNaturalDistributionLog(scm, :Ycat, conditioning_set)), - (:info, "Fitting Conditional Distribution Factor: Ycat | W₁₂, W₂₂, W₂₁, W₁₁, T₁, T₂"), - (:info, "Reusing or Updating Conditional Distribution Factor: T₁ | W₁₂, W₁₁"), - (:info, "Fitting Conditional Distribution Factor: T₂ | W₂₂, W₂₁"), - ) - @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) - - outcome_dist = get_conditional_distribution(scm, :Ycat, conditioning_set) - @test keys(fitted_params(outcome_dist.machine)) == (:linear_binary_classifier, :treatment_transformer) - @test Set(keys(outcome_dist.machine.data[1])) == conditioning_set - treatment_dist = get_conditional_distribution(scm, :T₂, Set([:W₂₂, :W₂₁])) - @test fitted_params(treatment_dist.machine).features == [:W₂₂, :W₂₁] - @test keys(treatment_dist.machine.data[1]) == (:W₂₂, :W₂₁) - - # Fits IATE required equations that haven't been fitted yet. - Ψ = IATE( - scm, - treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), - outcome =:Ycont, - ) - conditioning_set = Set([:W₁₂, :W₂₂, :W₂₁, :W₁₁, :T₁, :T₂]) - log_sequence = ( - (:info, TMLE.UsingNaturalDistributionLog(scm, :Ycont, conditioning_set)), - (:info, "Fitting Conditional Distribution Factor: Ycont | W₁₂, W₂₂, W₂₁, W₁₁, T₁, T₂"), - (:info, "Reusing or Updating Conditional Distribution Factor: T₁ | W₁₂, W₁₁"), - (:info, "Reusing or Updating Conditional Distribution Factor: T₂ | W₂₂, W₂₁"), - ) - @test_logs log_sequence... fit!(Ψ, dataset, verbosity=1) - outcome_dist = get_conditional_distribution(scm, :Ycont, conditioning_set) - @test keys(fitted_params(outcome_dist.machine)) == (:linear_regressor, :treatment_transformer) - @test Set(keys(outcome_dist.machine.data[1])) == conditioning_set - -end - -@testset "Test optimize_ordering" begin - rng = StableRNG(123) - scm = SCM( - SE(:T₁, [:W₁, :W₂]), - SE(:T₂, [:W₁, :W₂, :W₃]), - SE(:Y₁, [:T₁, :T₂, :W₁, :W₂, :C₁]), - SE(:Y₂, [:T₂, :W₁, :W₂, :W₃]), - ) - estimands = [ - ATE( - scm, - outcome=:Y₁, - treatment=(T₁=(case=1, control=0), T₂=(case="AC", control="CC")), - ), - IATE( - scm, - outcome=:Y₁, - treatment=(T₁=(case=1, control=0), T₂=(case="AA", control="CC")), - ), - ATE( - scm, - outcome=:Y₁, - treatment=(T₁=(case=1, control=0),), - ), - ATE( - scm, - outcome=:Y₂, - treatment=(T₂=(case="AC", control="CC"),), - ), - CM( - scm, - outcome=:Y₂, - treatment=(T₂="AC",), - ), - IATE( - scm, - outcome=:Y₁, - treatment=(T₁=(case=0, control=1), T₂=(case="AC", control="CC"),), - ), - CM( - scm, - outcome=:Y₂, - treatment=(T₂="CC",), - ), - ATE( - scm, - outcome=:Y₂, - treatment=(T₂=(case="AA", control="CC"),), - ), - ATE( - scm, - outcome=:Y₂, - treatment=(T₂=(case="AA", control="AC"),), - ), - ] - # Test estimand_key - @test TMLE.estimand_key(estimands[1]) == ("T₁_T₂", "Y₁") - @test TMLE.estimand_key(estimands[end]) == ("T₂", "Y₂") - # Non mutating function - estimands = shuffle(rng, estimands) - ordered_estimands = optimize_ordering(estimands) - expected_ordering = [ - # Y₁ - ATE(scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),)), - ATE(scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AC", control = "CC"))), - IATE(scm, outcome=:Y₁, treatment=(T₁ = (case = 0, control = 1), T₂ = (case = "AC", control = "CC"),)), - IATE(scm, outcome=:Y₁, treatment=(T₁ = (case = 1, control = 0),T₂ = (case = "AA", control = "CC"))), - # Y₂ - ATE(scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "AC"),)), - CM(scm, outcome=:Y₂, treatment=(T₂ = "CC",)), - ATE(scm, outcome=:Y₂, treatment=(T₂ = (case = "AA", control = "CC"),)), - CM(scm, outcome=:Y₂, treatment=(T₂ = "AC",)), - ATE(scm, outcome=:Y₂, treatment=(T₂ = (case = "AC", control = "CC"),)), - ] - @test ordered_estimands == expected_ordering - # Mutating function - optimize_ordering!(estimands) - @test estimands == ordered_estimands -end -@testset "Test indicator_fns & indicator_values" begin - scm = StaticConfoundedModel( - [:Y], - [:T₁, :T₂, :T₃], - [:W] - ) +@testset "Test CMCompositeEstimand" begin dataset = ( W = [1, 2, 3, 4, 5, 6, 7, 8], T₁ = ["A", "B", "A", "B", "A", "B", "A", "B"], @@ -268,32 +19,38 @@ end ) # Counterfactual Mean Ψ = CM( - scm, outcome=:Y, - treatment=(T₁="A", T₂=1), + treatment_values=(T₁="A", T₂=1), + treatment_confounders=(T₁=[:W], T₂=[:W]) ) indicator_fns = TMLE.indicator_fns(Ψ) @test indicator_fns == Dict(("A", 1) => 1.) - indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.selectcols(dataset, TMLE.treatments(Ψ))) @test indic_values == [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0] # ATE Ψ = ATE( - scm, outcome=:Y, - treatment=(T₁=(case="A", control="B"), T₂=(control=0, case=1)), + treatment_values=( + T₁=(case="A", control="B"), + T₂=(control=0, case=1) + ), + treatment_confounders=(T₁=[:W], T₂=[:W]) ) indicator_fns = TMLE.indicator_fns(Ψ) @test indicator_fns == Dict( ("A", 1) => 1.0, ("B", 0) => -1.0 ) - indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.selectcols(dataset, TMLE.treatments(Ψ))) @test indic_values == [0.0, -1.0, 1.0, 0.0, 0.0, -1.0, 1.0, 0.0] # 2-points IATE Ψ = IATE( - scm, outcome=:Y, - treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0)), + treatment_values=( + T₁=(case="A", control="B"), + T₂=(case=1, control=0) + ), + treatment_confounders=(T₁=[:W], T₂=[:W]) ) indicator_fns = TMLE.indicator_fns(Ψ) @test indicator_fns == Dict( @@ -302,13 +59,17 @@ end ("B", 1) => -1.0, ("B", 0) => 1.0 ) - indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.selectcols(dataset, TMLE.treatments(Ψ))) @test indic_values == [-1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0] # 3-points IATE Ψ = IATE( - scm, outcome=:Y, - treatment=(T₁=(case="A", control="B"), T₂=(case=1, control=0), T₃=(control="D", case="C")), + treatment_values=( + T₁=(case="A", control="B"), + T₂=(case=1, control=0), + T₃=(control="D", case="C") + ), + treatment_confounders=(T₁=[:W], T₂=[:W], T₃=[:W]) ) indicator_fns = TMLE.indicator_fns(Ψ) @test indicator_fns == Dict( @@ -321,7 +82,7 @@ end ("B", 1, "D") => 1.0, ("A", 0, "C") => -1.0 ) - indic_values = TMLE.indicator_values(indicator_fns, TMLE.treatments(dataset, Ψ)) + indic_values = TMLE.indicator_values(indicator_fns, TMLE.selectcols(dataset, TMLE.treatments(Ψ))) @test indic_values == [-1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0] end @@ -329,8 +90,6 @@ end @test isconcretetype(ATE) @test isconcretetype(IATE) @test isconcretetype(CM) - @test isconcretetype(TMLE.TMLEstimate{Float64}) - @test isconcretetype(TMLE.OSEstimate{Float64}) end end From 62b86575dadc6496881fdb08f7abe7e9d446c22b Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 11:26:11 +0100 Subject: [PATCH 070/151] clean testfile --- test/counterfactual_mean_based/estimands.jl | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index 6305247f..7aa0df7f 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -2,12 +2,6 @@ module TestEstimands using Test using TMLE -using CategoricalArrays -using Random -using StableRNGs -using MLJBase -using MLJGLMInterface - @testset "Test CMCompositeEstimand" begin dataset = ( @@ -87,9 +81,9 @@ using MLJGLMInterface end @testset "Test structs are concrete types" begin - @test isconcretetype(ATE) - @test isconcretetype(IATE) - @test isconcretetype(CM) + for type in Base.uniontypes(TMLE.CMCompositeEstimand) + @test isconcretetype(type) + end end end From deaad872af93a5b81809a68c9fd20be17c3553f1 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 17:35:59 +0100 Subject: [PATCH 071/151] more updates --- src/TMLE.jl | 2 +- ...t_and_covariate.jl => clever_covariate.jl} | 25 ++- src/counterfactual_mean_based/estimators.jl | 202 +++++++++++++----- src/counterfactual_mean_based/fluctuation.jl | 24 --- src/counterfactual_mean_based/gradient.jl | 2 +- src/estimates.jl | 15 ++ .../clever_covariate.jl | 134 ++++++++++++ .../non_regression_test.jl | 49 ++--- .../offset_and_covariate.jl | 141 ------------ test/estimators_and_estimates.jl | 47 +++- test/runtests.jl | 16 +- 11 files changed, 398 insertions(+), 259 deletions(-) rename src/counterfactual_mean_based/{offset_and_covariate.jl => clever_covariate.jl} (81%) create mode 100644 test/counterfactual_mean_based/clever_covariate.jl delete mode 100644 test/counterfactual_mean_based/offset_and_covariate.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index 703b43ca..31277415 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -53,7 +53,7 @@ include("adjustment.jl") include("counterfactual_mean_based/estimands.jl") include("counterfactual_mean_based/fluctuation.jl") -include("counterfactual_mean_based/offset_and_covariate.jl") +include("counterfactual_mean_based/clever_covariate.jl") include("counterfactual_mean_based/gradient.jl") include("counterfactual_mean_based/estimators.jl") diff --git a/src/counterfactual_mean_based/offset_and_covariate.jl b/src/counterfactual_mean_based/clever_covariate.jl similarity index 81% rename from src/counterfactual_mean_based/offset_and_covariate.jl rename to src/counterfactual_mean_based/clever_covariate.jl index 43203ed3..7fbb2f25 100644 --- a/src/counterfactual_mean_based/offset_and_covariate.jl +++ b/src/counterfactual_mean_based/clever_covariate.jl @@ -18,15 +18,6 @@ function truncate!(v::AbstractVector, ps_lowerbound::AbstractFloat) end end -function compute_offset(ŷ::UnivariateFiniteVector{Multiclass{2}}) - μy = expected_value(ŷ) - logit!(μy) - return μy -end -compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = expected_value(ŷ) -compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) - - function balancing_weights(Gs, dataset; ps_lowerbound=1e-8) jointlikelihood = ones(nrows(dataset)) for G ∈ Gs @@ -37,7 +28,13 @@ function balancing_weights(Gs, dataset; ps_lowerbound=1e-8) end """ - clever_covariate_and_weights(jointT, W, G, indicator_fns; ps_lowerbound=1e-8, weighted_fluctuation=false) + clever_covariate_and_weights( + Ψ::CMCompositeEstimand, + Gs::Tuple{Vararg{ConditionalDistributionEstimate}}, + dataset; + ps_lowerbound=1e-8, + weighted_fluctuation=false + ) Computes the clever covariate and weights that are used to fluctuate the initial Q. @@ -53,7 +50,13 @@ if `weighted_fluctuation = true`: where SpecialIndicator(t) is defined in `indicator_fns`. """ -function clever_covariate_and_weights(Ψ::CMCompositeEstimand, Gs, dataset; ps_lowerbound=1e-8, weighted_fluctuation=false) +function clever_covariate_and_weights( + Ψ::CMCompositeEstimand, + Gs::Tuple{Vararg{ConditionalDistributionEstimate}}, + dataset; + ps_lowerbound=1e-8, + weighted_fluctuation=false + ) # Compute the indicator values T = TMLE.selectcols(dataset, (p.estimand.outcome for p in Gs)) indic_vals = indicator_values(indicator_fns(Ψ), T) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 7417228b..3afcf327 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -1,64 +1,163 @@ -choose_initial_dataset(dataset, nomissing_dataset, resampling::Nothing) = dataset -choose_initial_dataset(dataset, nomissing_dataset, resampling) = nomissing_dataset - -function tmle!(Ψ::CMCompositeEstimand, models, dataset; - resampling=nothing, - adjustment_method=BackdoorAdjustment(), - verbosity=1, - force=false, - ps_lowerbound=1e-8, - weighted_fluctuation=false, - factors_cache=nothing +##################################################################### +### CMRelevantFactorsEstimator ### +##################################################################### + +struct CMRelevantFactorsEstimator <: Estimator + resampling + models +end + +get_train_validation_indices(resampling::Nothing, factors, dataset) = nothing + +function get_train_validation_indices(resampling::ResamplingStrategy, factors, dataset) + relevant_columns = collect(TMLE.variables(factors)) + outcome_variable = factors.outcome_mean.outcome + feature_variables = filter(x -> x !== outcome_variable, relevant_columns) + return collect(MLJBase.train_test_pairs( + resampling, + 1:nrows(dataset), + selectcols(dataset, feature_variables), + Tables.getcolumn(dataset, outcome_variable) + )) +end + +function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1) + # Lookup in cache + if haskey(cache, (estimand, estimator)) + return from_cache(cache, estimand, estimator; verbosity=verbosity) + end + # Otherwise estimate + models = estimator.models + # Get train validation indices + train_validation_indices = TMLE.get_train_validation_indices(estimator.resampling, estimand, dataset) + # Fit propensity score + propensity_score_estimate = Tuple( + ConditionalDistributionEstimator(train_validation_indices, models[factor.outcome])( + factor, + dataset; + cache=cache, + verbosity=verbosity + ) + for factor in estimand.propensity_score ) + # Fit outcome mean + outcome_mean = estimand.outcome_mean + model = models[outcome_mean.outcome] + outcome_mean_estimate = TMLE.ConditionalDistributionEstimator( + train_validation_indices, + model + )(outcome_mean, dataset; cache=cache, verbosity=verbosity) + # Build estimate + estimate = MLCMRelevantFactors(estimand, outcome_mean_estimate, propensity_score_estimate) + # Update cache + cache[(estimand, estimator)] = estimate + + return estimate +end + +##################################################################### +### TargetedCMRelevantFactorsEstimator ### +##################################################################### + +struct TargetedCMRelevantFactorsEstimator + model +end + +TargetedCMRelevantFactorsEstimator(Ψ, initial_factors_estimate; tol=nothing, ps_lowerbound=1e-8, weighted=false) = + TargetedCMRelevantFactorsEstimator(TMLE.Fluctuation(Ψ, initial_factors_estimate; + tol=tol, + ps_lowerbound=ps_lowerbound, + weighted=weighted + )) + +function (estimator::TargetedCMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1) + model = estimator.model + # Fluctuate outcome model + fluctuated_outcome_mean = MLConditionalDistributionEstimator(model)( + model.initial_factors.outcome_mean.estimand, + dataset, + verbosity=verbosity + ) + # Do not fluctuate propensity score + fluctuated_propensity_score = model.initial_factors.propensity_score + # Build estimate + estimate = MLCMRelevantFactors(estimand, fluctuated_outcome_mean, fluctuated_propensity_score) + # Update cache + cache[:last_fluctuation] = estimate + + return estimate +end + +##################################################################### +### TMLE ### +##################################################################### +mutable struct TMLEE <: Estimator + models + resampling + ps_lowerbound + weighted + tol +end + +TMLEE(models; resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) = + TMLEE(models, resampling, ps_lowerbound, weighted, tol) + +function (tmle::TMLEE)(Ψ, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors verbosity >= 1 && @info "Fitting the required equations..." - relevant_factors = TMLE.get_relevant_factors(Ψ; adjustment_method=adjustment_method) + relevant_factors = TMLE.get_relevant_factors(Ψ) nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) - initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, resampling) - initial_factors_estimate = TMLE.estimate(relevant_factors, resampling, models, initial_factors_dataset; - factors_cache=factors_cache, - verbosity=verbosity - ) + initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, tmle.resampling) + initial_factors_estimator = TMLE.CMRelevantFactorsEstimator(tmle.resampling, tmle.models) + initial_factors_estimate = initial_factors_estimator(relevant_factors, initial_factors_dataset; cache=cache, verbosity=verbosity) # Get propensity score truncation threshold n = nrows(nomissing_dataset) - ps_lowerbound = TMLE.ps_lower_bound(n, ps_lowerbound) - # Fit Fluctuation + ps_lowerbound = TMLE.ps_lower_bound(n, tmle.ps_lowerbound) + # Fluctuation initial factors verbosity >= 1 && @info "Performing TMLE..." - targeted_factors_estimate = TMLE.fluctuate(initial_factors_estimate, Ψ, nomissing_dataset; - tol=nothing, - verbosity=verbosity, - weighted_fluctuation=weighted_fluctuation, - ps_lowerbound=ps_lowerbound + targeted_factors_estimator = TMLE.TargetedCMRelevantFactorsEstimator( + Ψ, + initial_factors_estimate; + tol=tmle.tol, + ps_lowerbound=tmle.ps_lowerbound, + weighted=tmle.weighted ) + targeted_factors_estimate = targeted_factors_estimator(relevant_factors, nomissing_dataset; cache=cache, verbosity=verbosity) # Estimation results after TMLE IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." - return TMLEstimate(Ψ̂, IC), targeted_factors_estimate + # update!(cache, relevant_factors, targeted_factors_estimate) + return TMLEstimate(Ψ̂, IC), cache +end + +##################################################################### +### OSE ### +##################################################################### + +mutable struct OSE <: Estimator + models + resampling + ps_lowerbound end -function ose!(Ψ::CMCompositeEstimand, models, dataset; - adjustment_method=BackdoorAdjustment(), - verbosity=1, - force=false, - resampling=nothing, - ps_lowerbound=1e-8, - factors_cache=nothing) +OSE(models; resampling=nothing, ps_lowerbound=1e-8) = + OSE(models, resampling, ps_lowerbound) + +function (estimator::OSE)(Ψ::CMCompositeEstimand, dataset; verbosity=1) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors verbosity >= 1 && @info "Fitting the required equations..." - relevant_factors = TMLE.get_relevant_factors(Ψ; adjustment_method=adjustment_method) - nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) - initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, resampling) - initial_factors_estimate = TMLE.estimate(relevant_factors, resampling, models, initial_factors_dataset; - factors_cache=factors_cache, - verbosity=verbosity - ) + initial_factors = TMLE.get_relevant_factors(Ψ) + nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(initial_factors)) + initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, estimator.resampling) + initial_factors_estimator = TMLE.CMRelevantFactorsEstimator(estimator.resampling, estimator.models) + initial_factors_estimate = initial_factors_estimator(initial_factors, initial_factors_dataset, verbosity=verbosity) # Get propensity score truncation threshold n = nrows(nomissing_dataset) - ps_lowerbound = TMLE.ps_lower_bound(n, ps_lowerbound) + ps_lowerbound = TMLE.ps_lower_bound(n, estimator.ps_lowerbound) # Gradient and estimate IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) @@ -66,20 +165,25 @@ function ose!(Ψ::CMCompositeEstimand, models, dataset; return OSEstimate(Ψ̂ + mean(IC), IC), initial_factors_estimate end -function naive_plugin_estimate!(Ψ::CMCompositeEstimand, models, dataset; - adjustment_method=BackdoorAdjustment(), - verbosity=1, - force=false, - factors_cache=nothing) +##################################################################### +### NAIVE ### +##################################################################### + +mutable struct NAIVE <: Estimator + model +end + +function (estimator::NAIVE)(Ψ::CMCompositeEstimand, dataset; verbosity=1) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors verbosity >= 1 && @info "Fitting the required equations..." - relevant_factors = TMLE.get_relevant_factors(Ψ; adjustment_method=adjustment_method) + relevant_factors = TMLE.get_relevant_factors(Ψ) nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) - factors_estimate = TMLE.estimate(relevant_factors, nothing, models, dataset; - factors_cache=factors_cache, + outcome_mean_estimate = MLConditionalDistributionEstimator(estimator.model)( + relevant_factors.outcome_mean, + dataset; verbosity=verbosity ) - return mean(counterfactual_aggregate(Ψ, factors_estimate.outcome_mean, nomissing_dataset)) + return mean(counterfactual_aggregate(Ψ, outcome_mean_estimate, nomissing_dataset)) end \ No newline at end of file diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index 8384667a..176f2917 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -9,30 +9,6 @@ end Fluctuation(Ψ, initial_factors; tol=nothing, ps_lowerbound=1e-8, weighted=false) = Fluctuation(Ψ, initial_factors, tol, ps_lowerbound, weighted) -function fluctuate(initial_factors_estimate::MLCMRelevantFactors, Ψ, dataset; - tol=nothing, - verbosity=1, - weighted_fluctuation=false, - ps_lowerbound=1e-8, - factors_cache=nothing - ) - Qfluct_model = TMLE.Fluctuation(Ψ, initial_factors_estimate; - tol=tol, - ps_lowerbound=ps_lowerbound, - weighted=weighted_fluctuation - ) - fluctuated_outcome_mean = TMLE.estimate( - initial_factors_estimate.outcome_mean.estimand, - dataset, - Qfluct_model, - nothing; - factors_cache=factors_cache, - verbosity=verbosity - ) - fluctuated_propensity_score = initial_factors_estimate.propensity_score - return MLCMRelevantFactors(initial_factors_estimate.estimand, fluctuated_outcome_mean, fluctuated_propensity_score) -end - one_dimensional_path(target_scitype::Type{T}) where T <: AbstractVector{<:MLJBase.Continuous} = LinearRegressor(fit_intercept=false, offsetcol = :offset) one_dimensional_path(target_scitype::Type{T}) where T <: AbstractVector{<:Finite} = LinearBinaryClassifier(fit_intercept=false, offsetcol = :offset) diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index 9d1fbac9..89c2009c 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -9,7 +9,7 @@ For the ATE with binary treatment, confounded by W it is equal to: """ function counterfactual_aggregate(Ψ::CMCompositeEstimand, Q, dataset) - X = selectcols(dataset, featurenames(Q)) + X = selectcols(dataset, Q.estimand.parents) Ttemplate = selectcols(X, treatments(Ψ)) n = nrows(Ttemplate) ctf_agg = zeros(n) diff --git a/src/estimates.jl b/src/estimates.jl index 0056424a..3bb5b743 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -76,6 +76,21 @@ function likelihood(estimate::ConditionalDistributionEstimate, dataset) return pdf.(ŷ, y) end +function compute_offset(ŷ::UnivariateFiniteVector{Multiclass{2}}) + μy = expected_value(ŷ) + logit!(μy) + return μy +end + +compute_offset(ŷ::AbstractVector{<:Distributions.UnivariateDistribution}) = expected_value(ŷ) + +compute_offset(ŷ::AbstractVector{<:Real}) = expected_value(ŷ) + +function compute_offset(estimate::ConditionalDistributionEstimate, X) + ŷ = predict(estimate, X) + return compute_offset(ŷ) +end + ##################################################################### ### MLCMRelevantFactors ### ##################################################################### diff --git a/test/counterfactual_mean_based/clever_covariate.jl b/test/counterfactual_mean_based/clever_covariate.jl new file mode 100644 index 00000000..2b7e131d --- /dev/null +++ b/test/counterfactual_mean_based/clever_covariate.jl @@ -0,0 +1,134 @@ +module TestOffsetAndCovariate + +using Test +using TMLE +using CategoricalArrays +using MLJModels + +@testset "Test ps_lower_bound" begin + n = 7 + max_lb = 0.1 + # Nothing results in data adaptive lower bound no lower than max_lb + @test TMLE.ps_lower_bound(n, nothing) == max_lb + @test TMLE.data_adaptive_ps_lower_bound(n) == max_lb + @test TMLE.data_adaptive_ps_lower_bound(1000) == 0.02984228238321508 + # Otherwise use the provided threhsold provided it's lower than max_lb + @test TMLE.ps_lower_bound(n, 1e-8) == 1.0e-8 + @test TMLE.ps_lower_bound(n, 1) == 0.1 +end + +@testset "Test clever_covariate_and_weights: 1 treatment" begin + Ψ = ATE( + outcome=:Y, + treatment_values=(T=(case="a", control="b"),), + treatment_confounders=(T=[:W],), + ) + dataset = ( + T = categorical(["a", "b", "c", "a", "a", "b", "a"]), + Y = [1., 2., 3, 4, 5, 6, 7], + W = rand(7), + ) + distr_estimate = TMLE.MLConditionalDistributionEstimator(ConstantClassifier())( + TMLE.ConditionalDistribution(:T, [:W]), + dataset, + verbosity=0 + ) + weighted_fluctuation = true + ps_lowerbound = 1e-8 + cov, w = TMLE.clever_covariate_and_weights(Ψ, (distr_estimate,), dataset; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + + @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] + @test w == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] + + weighted_fluctuation = false + cov, w = TMLE.clever_covariate_and_weights(Ψ, (distr_estimate,), dataset; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] + @test w == ones(7) +end + +@testset "Test clever_covariate_and_weights: 2 treatments" begin + Ψ = IATE( + outcome = :Y, + treatment_values=( + T₁=(case=1, control=0), + T₂=(case=1, control=0) + ), + treatment_confounders=(T₁=[:W], T₂=[:W]) + ) + dataset = ( + T₁ = categorical([1, 0, 0, 1, 1, 1, 0]), + T₂ = categorical([1, 1, 1, 1, 1, 0, 0]), + Y = categorical([1, 1, 1, 1, 0, 0, 0]), + W = rand(7) + ) + T₁_distr_estimate = TMLE.MLConditionalDistributionEstimator(ConstantClassifier())( + TMLE.ConditionalDistribution(:T₁, [:W]), + dataset, + verbosity=0 + ) + T₂_distr_estimate = TMLE.MLConditionalDistributionEstimator(ConstantClassifier())( + TMLE.ConditionalDistribution(:T₂, [:W]), + dataset, + verbosity=0 + ) + ps_lowerbound = 1e-8 + weighted_fluctuation = false + + cov, w = TMLE.clever_covariate_and_weights(Ψ, (T₁_distr_estimate, T₂_distr_estimate), dataset; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 + @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] +end + +@testset "Test compute_offset, clever_covariate_and_weights: 3 treatments" begin + ## Third case: 3 Treatment variables + Ψ = IATE( + outcome =:Y, + treatment_values=( + T₁=(case="a", control="b"), + T₂=(case=1, control=2), + T₃=(case=true, control=false) + ), + treatment_confounders=( + T₁=[:W], + T₂=[:W], + T₃=[:W] + ) + ) + dataset = ( + T₁ = categorical(["a", "a", "b", "b", "c", "b", "b"]), + T₂ = categorical([3, 2, 1, 1, 2, 2, 2], ordered=true), + T₃ = categorical([true, false, true, false, false, false, false], ordered=true), + Y = [1., 2., 3, 4, 5, 6, 7], + W = rand(7), + ) + ps_lowerbound = 1e-8 + weighted_fluctuation = false + + cond_distr = Tuple( + TMLE.MLConditionalDistributionEstimator(ConstantClassifier())( + TMLE.ConditionalDistribution(T, [:W]), + dataset, + verbosity=0 + ) for T in (:T₁, :T₂, :T₃) + ) + + cov, w = TMLE.clever_covariate_and_weights(Ψ, cond_distr, dataset; + ps_lowerbound=ps_lowerbound, + weighted_fluctuation=weighted_fluctuation + ) + @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 + @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] +end + +end + +true \ No newline at end of file diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index c6db7eda..be4a69fd 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -8,6 +8,7 @@ using CategoricalArrays using MLJGLMInterface using MLJBase + @testset "Test ATE on perinatal dataset." begin # This is a non-regression test which was checked against the R tmle3 package dataset = CSV.read(joinpath("data", "perinatal.csv"), DataFrame, missingstring=["", "NA"]) @@ -20,27 +21,31 @@ using MLJBase Ψ = ATE( outcome=:haz01, - treatment=(parity01=(case=1, control=0),), - confounders=confounders + treatment_values=(parity01=(case=1, control=0),), + treatment_confounders=(parity01=confounders,) ) models = ( haz01 = with_encoder(LinearBinaryClassifier()), parity01 = LinearBinaryClassifier() ) - + resampling=nothing # No CV ps_lowerbound = 0.025 # Cutoff hardcoded in tmle3 - resampling = nothing # Vanilla TMLE - weighted_fluctuation = false # Unweighted fluctuation + weighted = false # Unweighted fluctuation verbosity = 1 # No logs - adjustment_method = BackdoorAdjustment() - factors_cache = nothing - tmle_result, targeted_factors = tmle!(Ψ, models, dataset; + tmle = TMLEE(models; resampling=resampling, - weighted_fluctuation=weighted_fluctuation, - ps_lowerbound=ps_lowerbound, - adjustment_method=adjustment_method, - verbosity=verbosity) + ps_lowerbound=ps_lowerbound, + weighted=weighted + ) + + tmle_result, cache = tmle(Ψ, dataset; verbosity=verbosity); tmle_result + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=verbosity); + cache + tmle.models = ( + haz01 = with_encoder(LinearBinaryClassifier()), + parity01 = LinearBinaryClassifier(fit_intercept=false) + ) @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 l, u = confint(OneSampleTTest(tmle_result)) @test l ≈ -0.279246 atol = 1e-6 @@ -48,25 +53,21 @@ using MLJBase @test OneSampleZTest(tmle_result) isa OneSampleZTest # OSE - ose_result, targeted_factors = ose!(Ψ, models, dataset; + ose = OSE(models; resampling=resampling, - ps_lowerbound=ps_lowerbound, - adjustment_method=adjustment_method, - verbosity=verbosity) + ps_lowerbound=ps_lowerbound + ) + ose_result, targeted_factors = ose(Ψ, dataset; verbosity=verbosity) ose_result # CV-TMLE - resampling = StratifiedCV(nfolds=10) - cv_tmle_result, targeted_factors = tmle!(Ψ, models, dataset; - resampling=resampling, - weighted_fluctuation=weighted_fluctuation, - ps_lowerbound=ps_lowerbound, - adjustment_method=adjustment_method, - verbosity=verbosity) + tmle.resampling = StratifiedCV(nfolds=10) + cv_tmle_result, targeted_factors = tmle(Ψ, dataset; verbosity=verbosity) cv_tmle_result # Naive - @test naive_plugin_estimate!(Ψ, models, dataset) ≈ -0.150078 atol = 1e-6 + naive = NAIVE(models.haz01) + @test naive(Ψ, dataset) ≈ -0.150078 atol = 1e-6 end diff --git a/test/counterfactual_mean_based/offset_and_covariate.jl b/test/counterfactual_mean_based/offset_and_covariate.jl deleted file mode 100644 index e489d360..00000000 --- a/test/counterfactual_mean_based/offset_and_covariate.jl +++ /dev/null @@ -1,141 +0,0 @@ -module TestOffsetAndCovariate - -using Test -using TMLE -using CategoricalArrays -using MLJModels -using Statistics - -@testset "Test ps_lower_bound" begin - dataset = ( - T = categorical([1, 0, 1, 1, 0, 1, 1]), - Y = [1., 2., 3, 4, 5, 6, 7], - W = rand(7), - ) - Ψ = ATE( - outcome=:Y, - treatment=(T=(case=1, control=0),), - confounders=:W, - ) - max_lb = 0.1 - fit!(Ψ, dataset, verbosity=0) - # Nothing results in data adaptive lower bound no lower than max_lb - @test TMLE.ps_lower_bound(Ψ, nothing) == max_lb - @test TMLE.data_adaptive_ps_lower_bound(Ψ) == max_lb - @test TMLE.data_adaptive_ps_lower_bound(1000) == 0.02984228238321508 - # Otherwise use the provided threhsold provided it's lower than max_lb - @test TMLE.ps_lower_bound(Ψ, 1e-8) == 1.0e-8 - @test TMLE.ps_lower_bound(Ψ, 1) == 0.1 -end - -@testset "Test compute_offset, clever_covariate_and_weights: 1 treatment" begin - ### In all cases, models are "constant" models - ## First case: 1 Treamtent variable - Ψ = ATE( - outcome=:Y, - treatment=(T=(case="a", control="b"),), - confounders=[:W₁, :W₂, :W₃], - treatment_model = ConstantClassifier(), - outcome_model = ConstantRegressor() - ) - dataset = ( - T = categorical(["a", "b", "c", "a", "a", "b", "a"]), - Y = [1., 2., 3, 4, 5, 6, 7], - W₁ = rand(7), - W₂ = rand(7), - W₃ = rand(7) - ) - fit!(Ψ, dataset, verbosity=0) - - offset = TMLE.compute_offset(Ψ) - @test offset == repeat([mean(dataset.Y)], 7) - weighted_fluctuation = true - ps_lowerbound = 1e-8 - X, y = TMLE.get_outcome_datas(Ψ) - cov, w = TMLE.clever_covariate_and_weights(Ψ, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - - @test cov == [1.0, -1.0, 0.0, 1.0, 1.0, -1.0, 1.0] - @test w == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] - - weighted_fluctuation = false - cov, w = TMLE.clever_covariate_and_weights(Ψ, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - @test cov == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] - @test w == ones(7) -end - -@testset "Test compute_offset, clever_covariate_and_weights: 2 treatments" begin - Ψ = IATE( - outcome = :Y, - treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), - confounders=[:W₁, :W₂, :W₃], - treatment_model = ConstantClassifier(), - outcome_model = ConstantClassifier() - ) - dataset = ( - T₁ = categorical([1, 0, 0, 1, 1, 1, 0]), - T₂ = categorical([1, 1, 1, 1, 1, 0, 0]), - Y = categorical([1, 1, 1, 1, 0, 0, 0]), - W₁ = rand(7), - W₂ = rand(7), - W₃ = rand(7) - ) - ps_lowerbound = 1e-8 - weighted_fluctuation = false - - fit!(Ψ, dataset, verbosity=0) - - # Because the outcome is binary, the offset is the logit - offset = TMLE.compute_offset(Ψ) - @test offset == repeat([0.28768207245178085], 7) - X, y = TMLE.get_outcome_datas(Ψ) - cov, w = TMLE.clever_covariate_and_weights(Ψ, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - @test cov ≈ [2.45, -3.266, -3.266, 2.45, 2.45, -6.125, 8.166] atol=1e-2 - @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] -end - -@testset "Test compute_offset, clever_covariate_and_weights: 3 treatments" begin - ## Third case: 3 Treatment variables - Ψ = IATE( - outcome =:Y, - treatment=(T₁=(case="a", control="b"), - T₂=(case=1, control=2), - T₃=(case=true, control=false)), - confounders=[:W], - treatment_model = ConstantClassifier(), - outcome_model = DeterministicConstantRegressor() - ) - dataset = ( - T₁ = categorical(["a", "a", "b", "b", "c", "b", "b"]), - T₂ = categorical([3, 2, 1, 1, 2, 2, 2], ordered=true), - T₃ = categorical([true, false, true, false, false, false, false], ordered=true), - Y = [1., 2., 3, 4, 5, 6, 7], - W = rand(7), - ) - ps_lowerbound = 1e-8 - weighted_fluctuation = false - - fit!(Ψ, dataset, verbosity=0) - - offset = TMLE.compute_offset(Ψ) - @test offset == repeat([4.0], 7) - X, y = TMLE.get_outcome_datas(Ψ) - cov, w = TMLE.clever_covariate_and_weights(Ψ, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - @test cov ≈ [0, 8.575, -21.4375, 8.575, 0, -4.2875, -4.2875] atol=1e-3 - @test w == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] -end - -end - -true \ No newline at end of file diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index 7ad010ab..9244b33f 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -5,6 +5,8 @@ using TMLE using MLJBase using DataFrames using MLJGLMInterface +using MLJModels +using LogExpFunctions verbosity = 1 n = 100 @@ -26,7 +28,7 @@ reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) @test ŷ == predict(estimate.machine, dataset[!, expected_features]) μ̂ = TMLE.expected_value(estimate, dataset) @test μ̂ == mean.(ŷ) - @test all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) + @test_skip all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) # The pdf is not necessarily between 0 and 1 # Uses the cache instead of fitting new_estimator = TMLE.MLConditionalDistributionEstimator(LinearRegressor()) @test TMLE.key(new_estimator) == TMLE.key(estimator) @@ -61,7 +63,7 @@ end @test ŷ[val] == ŷfold @test μ̂[val] == mean.(ŷfold) end - all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) + @test_skip all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) # Uses the cache instead of fitting new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( LinearRegressor(), @@ -86,6 +88,47 @@ end @test length(cache) == 3 end +@testset "Test compute_offset MLConditionalDistributionEstimator" begin + dataset = ( + T = categorical(["a", "b", "c", "a", "a", "b", "a"]), + Ycont = [1., 2., 3, 4, 5, 6, 7], + Ycat = categorical([1, 0, 0, 1, 1, 1, 0]), + W = rand(7), + ) + μYcont = mean(dataset.Ycont) + μYcat = mean(float(dataset.Ycat)) + # The model is probabilistic continuous, the offset is the mean of + # the conditional distribution + distr_estimate = TMLE.MLConditionalDistributionEstimator(ConstantRegressor())( + TMLE.ConditionalDistribution(:Ycont, [:W, :T]), + dataset, + verbosity=0 + ) + offset = TMLE.compute_offset(distr_estimate, dataset) + @test offset == mean.(predict(distr_estimate, dataset)) + @test offset == repeat([μYcont], 7) + # The model is deterministic, the offset is simply the output + # of the predict function which is assumed to correspond to the mean + # if the squared loss was optimized for by the underlying model + distr_estimate = TMLE.MLConditionalDistributionEstimator(DeterministicConstantRegressor())( + TMLE.ConditionalDistribution(:Ycont, [:W, :T]), + dataset, + verbosity=0 + ) + offset = TMLE.compute_offset(distr_estimate, dataset) + @test offset == predict(distr_estimate, dataset) + @test offset == repeat([μYcont], 7) + # The model is probabilistic binary, the offset is the logit + # of the mean of the conditional distribution + distr_estimate = TMLE.MLConditionalDistributionEstimator(ConstantClassifier())( + TMLE.ConditionalDistribution(:Ycat, [:W, :T]), + dataset, + verbosity=0 + ) + offset = TMLE.compute_offset(distr_estimate, dataset) + @test offset == repeat([logit(μYcat)], 7) +end + end true \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index afe47204..02c23051 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,19 +1,23 @@ using Test @time begin - @test include("distribution_factors.jl") - @test include("scm.jl") @test include("utils.jl") + @test include("estimands.jl") + @test include("estimators_and_estimates.jl") @test include("missing_management.jl") @test include("composition.jl") @test include("treatment_transformer.jl") - @test include("counterfactual_mean_based/non_regression_test.jl") - @test include("counterfactual_mean_based/offset_and_covariate.jl") + # @test include("scm.jl") + # @test include("adjustment.jl") + + @test include("counterfactual_mean_based/estimands.jl") + @test include("counterfactual_mean_based/clever_covariate.jl") @test include("counterfactual_mean_based/gradient.jl") @test include("counterfactual_mean_based/fluctuation.jl") + @test include("counterfactual_mean_based/estimators_and_estimates.jl") + @test include("counterfactual_mean_based/non_regression_test.jl") @test include("counterfactual_mean_based/double_robustness_ate.jl") @test include("counterfactual_mean_based/double_robustness_iate.jl") @test include("counterfactual_mean_based/3points_interactions.jl") - @test include("counterfactual_mean_based/estimators.jl") - @test include("counterfactual_mean_based/estimands.jl") + end \ No newline at end of file From 10ccef06612f882c6792bf46eb9c3adb502f0e94 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 17:57:54 +0100 Subject: [PATCH 072/151] update gradient --- src/TMLE.jl | 4 +- src/counterfactual_mean_based/estimands.jl | 35 ++++ src/counterfactual_mean_based/estimates.jl | 189 ++++++++++++++++++++ src/estimands.jl | 33 +--- src/estimates.jl | 190 --------------------- test/counterfactual_mean_based/gradient.jl | 35 ++-- 6 files changed, 248 insertions(+), 238 deletions(-) create mode 100644 src/counterfactual_mean_based/estimates.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index 31277415..3ca9b54f 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -52,10 +52,12 @@ include("treatment_transformer.jl") include("adjustment.jl") include("counterfactual_mean_based/estimands.jl") +include("counterfactual_mean_based/estimators.jl") +include("counterfactual_mean_based/estimates.jl") include("counterfactual_mean_based/fluctuation.jl") include("counterfactual_mean_based/clever_covariate.jl") include("counterfactual_mean_based/gradient.jl") -include("counterfactual_mean_based/estimators.jl") + # ############################################################################# # PRECOMPILATION WORKLOAD diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index d0819a9d..122e5605 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -1,3 +1,38 @@ +##################################################################### +### CMRelevantFactors ### +##################################################################### + +""" +Defines relevant factors that need to be estimated in order to estimate any +Counterfactual Mean composite estimand (see `CMCompositeEstimand`). +""" +struct CMRelevantFactors <: Estimand + outcome_mean::ConditionalDistribution + propensity_score::Tuple{Vararg{ConditionalDistribution}} +end + +CMRelevantFactors(outcome_mean, propensity_score::ConditionalDistribution) = + CMRelevantFactors(outcome_mean, (propensity_score,)) + +CMRelevantFactors(outcome_mean, propensity_score) = + CMRelevantFactors(outcome_mean, propensity_score) + +CMRelevantFactors(;outcome_mean, propensity_score) = + CMRelevantFactors(outcome_mean, propensity_score) + +string_repr(estimand::CMRelevantFactors) = + string("Composite Factor: \n", + "----------------\n- ", + string_repr(estimand.outcome_mean),"\n- ", + join((string_repr(f) for f in estimand.propensity_score), "\n- ")) + +variables(estimand::CMRelevantFactors) = + Tuple(union(variables(estimand.outcome_mean), (variables(est) for est in estimand.propensity_score)...)) + +##################################################################### +### Functionals ### +##################################################################### + ESTIMANDS_DOCS = Dict( :CM => (formula="``CM(Y, T=t) = E[Y|do(T=t)]``",), :ATE => (formula="``ATE(Y, T, case, control) = E[Y|do(T=case)] - E[Y|do(T=control)``",), diff --git a/src/counterfactual_mean_based/estimates.jl b/src/counterfactual_mean_based/estimates.jl new file mode 100644 index 00000000..c759a96c --- /dev/null +++ b/src/counterfactual_mean_based/estimates.jl @@ -0,0 +1,189 @@ +##################################################################### +### MLCMRelevantFactors ### +##################################################################### + +""" +Holds a Sample Split Machine Learning set of estimates (outcome mean, propensity score) +for counterfactual mean based estimands' relevant factors. +""" +struct MLCMRelevantFactors <: Estimate + estimand::CMRelevantFactors + outcome_mean::ConditionalDistributionEstimate + propensity_score::Tuple{Vararg{ConditionalDistributionEstimate}} +end + +string_repr(estimate::MLCMRelevantFactors) = string( + "Composite Factor Estimate: \n", + "-------------------------\n- ", + string_repr(estimate.outcome_mean),"\n- ", + join((string_repr(f) for f in estimate.propensity_score), "\n- ") +) + + +##################################################################### +### One Dimensional Estimates ### +##################################################################### + + +struct TMLEstimate{T<:AbstractFloat} <: Estimate + Ψ̂::T + IC::Vector{T} +end + +struct OSEstimate{T<:AbstractFloat} <: Estimate + Ψ̂::T + IC::Vector{T} +end + +const EICEstimate = Union{TMLEstimate, OSEstimate} + +function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) + testresult = OneSampleTTest(est) + data = [estimate(est) confint(testresult) pvalue(testresult);] + pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"]) +end + +struct ComposedEstimate{T<:AbstractFloat} <: Estimate + Ψ̂::Array{T} + σ̂::Matrix{T} + n::Int +end + +function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) + if length(est.σ̂) !== 1 + println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) + else + testresult = OneSampleTTest(est) + data = [estimate(est) confint(testresult) pvalue(testresult);] + headers = ["Estimate", "95% Confidence Interval", "P-value"] + pretty_table(io, data;header=headers) + end +end + +""" + Distributions.estimate(r::EICEstimate) + +Retrieves the final estimate: after the TMLE step. +""" +Distributions.estimate(est::EICEstimate) = est.Ψ̂ + +""" + Distributions.estimate(r::ComposedEstimate) + +Retrieves the final estimate: after the TMLE step. +""" +Distributions.estimate(est::ComposedEstimate) = + length(est.Ψ̂) == 1 ? est.Ψ̂[1] : est.Ψ̂ + +""" + var(r::EICEstimate) + +Computes the estimated variance associated with the estimate. +""" +Statistics.var(est::EICEstimate) = + var(est.IC)/size(est.IC, 1) + +""" + var(r::ComposedEstimate) + +Computes the estimated variance associated with the estimate. +""" +Statistics.var(est::ComposedEstimate) = + length(est.σ̂) == 1 ? est.σ̂[1] / est.n : est.σ̂ ./ est.n + + +""" + OneSampleZTest(r::EICEstimate, Ψ₀=0) + +Performs a Z test on the EICEstimate. +""" +HypothesisTests.OneSampleZTest(est::EICEstimate, Ψ₀=0) = + OneSampleZTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) + +""" + OneSampleTTest(r::EICEstimate, Ψ₀=0) + +Performs a T test on the EICEstimate. +""" +HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) = + OneSampleTTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) + +""" + OneSampleTTest(r::ComposedEstimate, Ψ₀=0) + +Performs a T test on the ComposedEstimate. +""" +function HypothesisTests.OneSampleTTest(est::ComposedEstimate, Ψ₀=0) + @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleTTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) +end + +""" + OneSampleZTest(r::ComposedEstimate, Ψ₀=0) + +Performs a T test on the ComposedEstimate. +""" +function HypothesisTests.OneSampleZTest(est::ComposedEstimate, Ψ₀=0) + @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleZTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) +end + +""" + compose(f, estimation_results::Vararg{EICEstimate, N}) where N + +Provides an estimator of f(estimation_results...). + +# Mathematical details + +The following is a summary from `Asymptotic Statistics`, A. W. van der Vaart. + +Consider k TMLEs computed from a dataset of size n and embodied by Tₙ = (T₁,ₙ, ..., Tₖ,ₙ). +Since each of them is asymptotically normal, the multivariate CLT provides the joint +distribution: + + √n(Tₙ - Ψ₀) ↝ N(0, Σ), + +where Σ is the covariance matrix of the TMLEs influence curves. + +Let f:ℜᵏ→ℜᵐ, be a differentiable map at Ψ₀. Then, the delta method provides the +limiting distribution of √n(f(Tₙ) - f(Ψ₀)). Because Tₙ is normal, the result is: + + √n(f(Tₙ) - f(Ψ₀)) ↝ N(0, ∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ), + +where ∇f(Ψ₀):ℜᵏ→ℜᵐ is a linear map such that by abusing notations and identifying the +function with the multiplication matrix: ∇f(Ψ₀):h ↦ ∇f(Ψ₀) ̇h. And the matrix ∇f(Ψ₀) is +the jacobian of f at Ψ₀. + +Hence, the only thing we need to do is: +- Compute the covariance matrix Σ +- Compute the jacobian ∇f, which can be done using Julia's automatic differentiation facilities. +- The final estimator is normal with mean f₀=f(Ψ₀) and variance σ₀=∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ + +# Arguments + +- f: An array-input differentiable map. +- estimation_results: 1 or more `EICEstimate` structs. + +# Examples + +Assuming `res₁` and `res₂` are TMLEs: + +```julia +f(x, y) = [x^2 - y, y - 3x] +compose(f, res₁, res₂) +``` +""" +function compose(f, estimators::Vararg{EICEstimate, N}; backend=AD.ZygoteBackend()) where N + Σ = cov(estimators...) + estimates = [estimate(r) for r in estimators] + f₀, Js = AD.value_and_jacobian(backend, f, estimates...) + J = hcat(Js...) + n = size(first(estimators).IC, 1) + σ₀ = J*Σ*J' + return ComposedEstimate(collect(f₀), σ₀, n) +end + +function Statistics.cov(estimators::Vararg{EICEstimate, N}) where N + X = hcat([r.IC for r in estimators]...) + return Statistics.cov(X, dims=1, corrected=true) +end diff --git a/src/estimands.jl b/src/estimands.jl index a7ffdb68..d486ef03 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -115,35 +115,4 @@ Defines an Expected Value estimand ``parents → E[Outcome|parents]``. At the moment there is no distinction between an Expected Value and a Conditional Distribution because they are estimated in the same way. """ -const ExpectedValue = ConditionalDistribution - -##################################################################### -### CMRelevantFactors ### -##################################################################### - -""" -Defines relevant factors that need to be estimated in order to estimate any -Counterfactual Mean composite estimand (see `CMCompositeEstimand`). -""" -struct CMRelevantFactors <: Estimand - outcome_mean::ConditionalDistribution - propensity_score::Tuple{Vararg{ConditionalDistribution}} -end - -CMRelevantFactors(outcome_mean, propensity_score::ConditionalDistribution) = - CMRelevantFactors(outcome_mean, (propensity_score,)) - -CMRelevantFactors(outcome_mean, propensity_score) = - CMRelevantFactors(outcome_mean, propensity_score) - -CMRelevantFactors(;outcome_mean, propensity_score) = - CMRelevantFactors(outcome_mean, propensity_score) - -string_repr(estimand::CMRelevantFactors) = - string("Composite Factor: \n", - "----------------\n- ", - string_repr(estimand.outcome_mean),"\n- ", - join((string_repr(f) for f in estimand.propensity_score), "\n- ")) - -variables(estimand::CMRelevantFactors) = - Tuple(union(variables(estimand.outcome_mean), (variables(est) for est in estimand.propensity_score)...)) +const ExpectedValue = ConditionalDistribution \ No newline at end of file diff --git a/src/estimates.jl b/src/estimates.jl index 3bb5b743..5e33ec1d 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -90,193 +90,3 @@ function compute_offset(estimate::ConditionalDistributionEstimate, X) ŷ = predict(estimate, X) return compute_offset(ŷ) end - -##################################################################### -### MLCMRelevantFactors ### -##################################################################### - -""" -Holds a Sample Split Machine Learning set of estimates (outcome mean, propensity score) -for counterfactual mean based estimands' relevant factors. -""" -struct MLCMRelevantFactors <: Estimate - estimand::CMRelevantFactors - outcome_mean::ConditionalDistributionEstimate - propensity_score::Tuple{Vararg{ConditionalDistributionEstimate}} -end - -string_repr(estimate::MLCMRelevantFactors) = string( - "Composite Factor Estimate: \n", - "-------------------------\n- ", - string_repr(estimate.outcome_mean),"\n- ", - join((string_repr(f) for f in estimate.propensity_score), "\n- ") -) - - -##################################################################### -### One Dimensional Estimates ### -##################################################################### - - -struct TMLEstimate{T<:AbstractFloat} <: Estimate - Ψ̂::T - IC::Vector{T} -end - -struct OSEstimate{T<:AbstractFloat} <: Estimate - Ψ̂::T - IC::Vector{T} -end - -const EICEstimate = Union{TMLEstimate, OSEstimate} - -function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) - testresult = OneSampleTTest(est) - data = [estimate(est) confint(testresult) pvalue(testresult);] - pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"]) -end - -struct ComposedEstimate{T<:AbstractFloat} <: Estimate - Ψ̂::Array{T} - σ̂::Matrix{T} - n::Int -end - -function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) - if length(est.σ̂) !== 1 - println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) - else - testresult = OneSampleTTest(est) - data = [estimate(est) confint(testresult) pvalue(testresult);] - headers = ["Estimate", "95% Confidence Interval", "P-value"] - pretty_table(io, data;header=headers) - end -end - -""" - Distributions.estimate(r::EICEstimate) - -Retrieves the final estimate: after the TMLE step. -""" -Distributions.estimate(est::EICEstimate) = est.Ψ̂ - -""" - Distributions.estimate(r::ComposedEstimate) - -Retrieves the final estimate: after the TMLE step. -""" -Distributions.estimate(est::ComposedEstimate) = - length(est.Ψ̂) == 1 ? est.Ψ̂[1] : est.Ψ̂ - -""" - var(r::EICEstimate) - -Computes the estimated variance associated with the estimate. -""" -Statistics.var(est::EICEstimate) = - var(est.IC)/size(est.IC, 1) - -""" - var(r::ComposedEstimate) - -Computes the estimated variance associated with the estimate. -""" -Statistics.var(est::ComposedEstimate) = - length(est.σ̂) == 1 ? est.σ̂[1] / est.n : est.σ̂ ./ est.n - - -""" - OneSampleZTest(r::EICEstimate, Ψ₀=0) - -Performs a Z test on the EICEstimate. -""" -HypothesisTests.OneSampleZTest(est::EICEstimate, Ψ₀=0) = - OneSampleZTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) - -""" - OneSampleTTest(r::EICEstimate, Ψ₀=0) - -Performs a T test on the EICEstimate. -""" -HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) = - OneSampleTTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) - -""" - OneSampleTTest(r::ComposedEstimate, Ψ₀=0) - -Performs a T test on the ComposedEstimate. -""" -function HypothesisTests.OneSampleTTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleTTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) -end - -""" - OneSampleZTest(r::ComposedEstimate, Ψ₀=0) - -Performs a T test on the ComposedEstimate. -""" -function HypothesisTests.OneSampleZTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleZTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) -end - -""" - compose(f, estimation_results::Vararg{EICEstimate, N}) where N - -Provides an estimator of f(estimation_results...). - -# Mathematical details - -The following is a summary from `Asymptotic Statistics`, A. W. van der Vaart. - -Consider k TMLEs computed from a dataset of size n and embodied by Tₙ = (T₁,ₙ, ..., Tₖ,ₙ). -Since each of them is asymptotically normal, the multivariate CLT provides the joint -distribution: - - √n(Tₙ - Ψ₀) ↝ N(0, Σ), - -where Σ is the covariance matrix of the TMLEs influence curves. - -Let f:ℜᵏ→ℜᵐ, be a differentiable map at Ψ₀. Then, the delta method provides the -limiting distribution of √n(f(Tₙ) - f(Ψ₀)). Because Tₙ is normal, the result is: - - √n(f(Tₙ) - f(Ψ₀)) ↝ N(0, ∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ), - -where ∇f(Ψ₀):ℜᵏ→ℜᵐ is a linear map such that by abusing notations and identifying the -function with the multiplication matrix: ∇f(Ψ₀):h ↦ ∇f(Ψ₀) ̇h. And the matrix ∇f(Ψ₀) is -the jacobian of f at Ψ₀. - -Hence, the only thing we need to do is: -- Compute the covariance matrix Σ -- Compute the jacobian ∇f, which can be done using Julia's automatic differentiation facilities. -- The final estimator is normal with mean f₀=f(Ψ₀) and variance σ₀=∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ - -# Arguments - -- f: An array-input differentiable map. -- estimation_results: 1 or more `EICEstimate` structs. - -# Examples - -Assuming `res₁` and `res₂` are TMLEs: - -```julia -f(x, y) = [x^2 - y, y - 3x] -compose(f, res₁, res₂) -``` -""" -function compose(f, estimators::Vararg{EICEstimate, N}; backend=AD.ZygoteBackend()) where N - Σ = cov(estimators...) - estimates = [estimate(r) for r in estimators] - f₀, Js = AD.value_and_jacobian(backend, f, estimates...) - J = hcat(Js...) - n = size(first(estimators).IC, 1) - σ₀ = J*Σ*J' - return ComposedEstimate(collect(f₀), σ₀, n) -end - -function Statistics.cov(estimators::Vararg{EICEstimate, N}) where N - X = hcat([r.IC for r in estimators]...) - return Statistics.cov(X, dims=1, corrected=true) -end diff --git a/test/counterfactual_mean_based/gradient.jl b/test/counterfactual_mean_based/gradient.jl index 33ca8f0b..f610ad7a 100644 --- a/test/counterfactual_mean_based/gradient.jl +++ b/test/counterfactual_mean_based/gradient.jl @@ -27,36 +27,41 @@ end @testset "Test gradient_and_estimate" begin ps_lowerbound = 1e-8 Ψ = ATE( - outcome = :Y, - treatment = (T=(case=1, control=0),), - confounders = [:W], - treatment_model = LogisticClassifier(), - outcome_model = with_encoder(InteractionTransformer(order=2) |> LinearRegressor()) + outcome = :Y, + treatment_values = (T=(case=1, control=0),), + treatment_confounders = (T=[:W],), ) dataset = one_treatment_dataset(;n=100) - factors = fit!(Ψ, dataset, verbosity=0) - - # Retrieve outcome model - Q = factors.Y - G = (T=factors.T, ) - linear_model = fitted_params(Q).deterministic_pipeline.linear_regressor + η = TMLE.CMRelevantFactors( + TMLE.ConditionalDistribution(:Y, [:T, :W]), + TMLE.ConditionalDistribution(:T, [:W]) + ) + η̂ = TMLE.CMRelevantFactorsEstimator( + nothing, + (Y=with_encoder(InteractionTransformer(order=2) |> LinearRegressor()), T = LogisticClassifier()) + ) + η̂ₙ = η̂(η, dataset, verbosity = 0) + # Retrieve conditional distributions and fitted_params + Q = η̂ₙ.outcome_mean + G = η̂ₙ.propensity_score + linear_model = fitted_params(Q.machine).deterministic_pipeline.linear_regressor intercept = linear_model.intercept coefs = Dict(linear_model.coefs) # Counterfactual aggregate - ctf_agg = TMLE.counterfactual_aggregate(Ψ, Q, G) + ctf_agg = TMLE.counterfactual_aggregate(Ψ, η̂ₙ.outcome_mean, dataset) expected_ctf_agg = (intercept .+ coefs[:T] .+ dataset.W.*coefs[:W] .+ dataset.W.*coefs[:T_W]) .- (intercept .+ dataset.W.*coefs[:W]) @test ctf_agg ≈ expected_ctf_agg atol=1e-10 # Gradient Y|X - H = 1 ./ pdf.(predict(G.T.machine), dataset.T) .* [t == 1 ? 1. : -1. for t in dataset.T] + H = 1 ./ pdf.(predict(G[1].machine), dataset.T) .* [t == 1 ? 1. : -1. for t in dataset.T] expected_∇YX = H .* (dataset.Y .- predict(Q.machine)) - ∇YX = TMLE.∇YX(Ψ, Q, G; ps_lowerbound=ps_lowerbound) + ∇YX = TMLE.∇YX(Ψ, Q, G, dataset; ps_lowerbound=ps_lowerbound) @test expected_∇YX == ∇YX # Gradient W expectedΨ̂ = mean(ctf_agg) ∇W = TMLE.∇W(ctf_agg, expectedΨ̂) @test ∇W == ctf_agg .- expectedΨ̂ # gradient_and_estimate - IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, factors; ps_lowerbound=ps_lowerbound) + IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, η̂ₙ, dataset; ps_lowerbound=ps_lowerbound) @test expectedΨ̂ == Ψ̂ @test IC == ∇YX .+ ∇W end From 948e7b24402701c1fd33c9c0f0e3d528ccebd90b Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 20:55:13 +0100 Subject: [PATCH 073/151] more updates --- src/counterfactual_mean_based/estimators.jl | 2 +- .../3points_interactions.jl | 33 ++-- test/counterfactual_mean_based/fluctuation.jl | 168 +++++++----------- test/helper_fns.jl | 11 +- 4 files changed, 95 insertions(+), 119 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 3afcf327..df0f45bb 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -81,7 +81,7 @@ function (estimator::TargetedCMRelevantFactorsEstimator)(estimand, dataset; cach # Do not fluctuate propensity score fluctuated_propensity_score = model.initial_factors.propensity_score # Build estimate - estimate = MLCMRelevantFactors(estimand, fluctuated_outcome_mean, fluctuated_propensity_score) + estimate = MLCMRelevantFactors(estimand, fluctuated_outcome_mean, fluctuated_propensity_score) # Update cache cache[:last_fluctuation] = estimate diff --git a/test/counterfactual_mean_based/3points_interactions.jl b/test/counterfactual_mean_based/3points_interactions.jl index 382748c9..38b5ddee 100644 --- a/test/counterfactual_mean_based/3points_interactions.jl +++ b/test/counterfactual_mean_based/3points_interactions.jl @@ -21,31 +21,36 @@ function dataset_scm_and_truth(;n=1000) Y = 2 .- 2T₁.*T₂.*T₃.*(W .+ 10) + rand(rng, Normal(0, 0.03), n) dataset = (W=W, T₁=categorical(T₁), T₂=categorical(T₂), T₃=categorical(T₃), Y=Y) - scm = StaticConfoundedModel( - [:Y], - [:T₁, :T₂, :T₃], - :W, - outcome_model = TreatmentTransformer() |> InteractionTransformer(order=3) |> LinearRegressor(), - treatment_model = LogisticClassifier(lambda=0) - ) truth = -21 - return dataset, scm, truth + return dataset, truth end @testset "Test 3-points interactions" begin - dataset, scm, Ψ₀ = dataset_scm_and_truth(;n=1000) + dataset, Ψ₀ = dataset_scm_and_truth(;n=1000) Ψ = IATE( - scm, outcome = :Y, - treatment = (T₁=(case=true, control=false), T₂=(case=true, control=false), T₃=(case=1, control=0)) + treatment_values = ( + T₁=(case=true, control=false), + T₂=(case=true, control=false), + T₃=(case=1, control=0) + ), + treatment_confounders = (T₁=[:W], T₂=[:W], T₃=[:W]) + ) + models = ( + Y = TreatmentTransformer() |> InteractionTransformer(order=3) |> LinearRegressor(), + T₁ = LogisticClassifier(lambda=0), + T₂ = LogisticClassifier(lambda=0), + T₃ = LogisticClassifier(lambda=0) ) - result, fluctuation = tmle!(Ψ, dataset, verbosity=0) + tmle = TMLEE(models) + result, cache = tmle(Ψ, dataset, verbosity=0); test_coverage(result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(result; atol=1e-10) - result, fluctuation = tmle!(Ψ, dataset, verbosity=0, weighted_fluctuation=true) + tmle.weighted = true + result, cache = tmle(Ψ, dataset, verbosity=0) test_coverage(result, Ψ₀) test_mean_inf_curve_almost_zero(result; atol=1e-10) diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index 613d86f2..ceceb986 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -4,81 +4,76 @@ using Test using TMLE using MLJModels using MLJBase - -function test_fluctuation_fit(Ψ, ps_lowerbound, weighted_fluctuation) - Q⁰ = TMLE.get_outcome_model(Ψ) - X, y = Q⁰.data - Q = machine( - TMLE.Fluctuation(Ψ, 0.1, ps_lowerbound, weighted_fluctuation), - X, - y - ) - fit!(Q, verbosity=0) - Xfluct, weights = TMLE.clever_covariate_offset_and_weights(Ψ, Q⁰, X; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation - ) - @test Q.cache.weighted_covariate == Xfluct.covariate .* weights -end +using DataFrames @testset "Test Fluctuation with 1 Treatments" begin - ### In all cases, models are "constant" models - ## First case: 1 Treamtent variable Ψ = ATE( outcome=:Y, - confounders=[:W₁, :W₂, :W₃], - treatment=(T=(case="a", control="b"),), - treatment_model = ConstantClassifier(), - outcome_model = ConstantRegressor() + treatment_confounders=(T=[:W₁, :W₂, :W₃],), + treatment_values=(T=(case="a", control="b"),), ) - dataset = ( + dataset = DataFrame( T = categorical(["a", "b", "c", "a", "a", "b", "a"]), Y = [1., 2., 3, 4, 5, 6, 7], W₁ = rand(7), W₂ = rand(7), W₃ = rand(7) ) - initial_factors = fit!(Ψ, dataset, verbosity=0) - ps_lowerbound = 1e-8 - weighted_fluctuation = true - targeted_factors = TMLE.fluctuate(initial_factors, Ψ; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation, - verbosity=1, - tol=nothing + η = TMLE.CMRelevantFactors( + TMLE.ConditionalDistribution(:Y, [:T, :W₁, :W₂, :W₃]), + TMLE.ConditionalDistribution(:T, [:W₁, :W₂, :W₃]) ) - @test targeted_factors.T === initial_factors.T - @test targeted_factors.Y !== initial_factors.Y - @test targeted_factors.Y.machine.cache.weighted_covariate == [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] - Q = targeted_factors.Y - X = Q.machine.data[1] - Xfluct, weights = TMLE.clever_covariate_offset_and_weights(Q.model, X) - @test weights == [1.75, 3.5, 7.0, 1.75, 1.75, 3.5, 1.75] - @test targeted_factors.Y.machine.cache.weighted_covariate == Xfluct.covariate .* weights + η̂ = TMLE.CMRelevantFactorsEstimator( + nothing, + (Y=with_encoder(ConstantRegressor()), T = ConstantClassifier()) + ) + η̂ₙ = η̂(η, dataset, verbosity = 0) + X = dataset[!, collect(η̂ₙ.outcome_mean.estimand.parents)] + y = dataset[!, η̂ₙ.outcome_mean.estimand.outcome] - weighted_fluctuation = false - targeted_factors = TMLE.fluctuate(initial_factors, Ψ; + ps_lowerbound = 1e-8 + # Weighted fluctuation + expected_weights = [1.75, 3.5, 7., 1.75, 1.75, 3.5, 1.75] + expected_covariate = [1., -1., 0.0, 1., 1., -1., 1.] + fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; + tol=nothing, ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation, - verbosity=1, - tol=nothing + weighted=true, ) - Q = targeted_factors.Y - X = Q.machine.data[1] - Xfluct, weights = TMLE.clever_covariate_offset_and_weights(Q.model, X) + fitresult, cache, report = MLJBase.fit(fluctuation, 0, X, y) + @test fitted_params(fitresult.one_dimensional_path).features == [:covariate] + @test cache.weighted_covariate == expected_weights .* expected_covariate + @test cache.training_expected_value isa AbstractVector + Xfluct, weights = TMLE.clever_covariate_offset_and_weights(fluctuation, X) + @test weights == expected_weights + @test fitresult.one_dimensional_path.data[3] == expected_weights + @test fitresult.one_dimensional_path.data[1].covariate == expected_covariate + # Unweighted fluctuation + fluctuation.weighted = false + expected_weights = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + expected_covariate = [1.75, -3.5, 0.0, 1.75, 1.75, -3.5, 1.75] + fitresult, cache, report = MLJBase.fit(fluctuation, 0, X, y) + @test fitresult.one_dimensional_path.data[3] == expected_weights + Xfluct, weights = TMLE.clever_covariate_offset_and_weights(fluctuation, X) @test weights == [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + @test Xfluct.covariate == expected_covariate + @test fitresult.one_dimensional_path.data[1].covariate == expected_covariate end @testset "Test Fluctuation with 2 Treatments" begin Ψ = IATE( outcome =:Y, - treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)), - confounders = [:W₁, :W₂, :W₃], - treatment_model = ConstantClassifier(), - outcome_model = ConstantClassifier() + treatment_values=( + T₁=(case=1, control=0), + T₂=(case=1, control=0) + ), + treatment_confounders = ( + T₁=[:W₁, :W₂, :W₃], + T₂=[:W₁, :W₂, :W₃] + ) ) - dataset = ( + dataset = DataFrame( T₁ = categorical([1, 0, 0, 1, 1, 1, 0]), T₂ = categorical([1, 1, 1, 1, 1, 0, 0]), Y = categorical([1, 1, 1, 1, 0, 0, 0]), @@ -86,57 +81,32 @@ end W₂ = rand(7), W₃ = rand(7) ) - ps_lowerbound = 1e-8 - weighted_fluctuation = false - - initial_factors = fit!(Ψ, dataset, verbosity=0) - targeted_factors = TMLE.fluctuate(initial_factors, Ψ; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation, - verbosity=1, - tol=nothing - ) - @test targeted_factors.T₁ === initial_factors.T₁ - @test targeted_factors.T₂ === initial_factors.T₂ - @test targeted_factors.Y !== initial_factors.Y - @test targeted_factors.Y.machine.cache.weighted_covariate ≈ - [2.45, -3.27, -3.27, 2.45, 2.45, -6.13, 8.17] atol=0.01 -end - -@testset "Test Fluctuation with 3 Treatments" begin - ## Third case: 3 Treatment variables - Ψ = IATE( - outcome = :Y, - treatment=(T₁=(case="a", control="b"), - T₂=(case=1, control=2), - T₃=(case=true, control=false)), - confounders=:W, - treatment_model = ConstantClassifier(), - outcome_model = DeterministicConstantRegressor() + η = TMLE.CMRelevantFactors( + TMLE.ConditionalDistribution(:Y, [:T₁, :T₂, :W₁, :W₂, :W₃]), + ( + TMLE.ConditionalDistribution(:T₁, [:W₁, :W₂, :W₃]), + TMLE.ConditionalDistribution(:T₂, [:W₁, :W₂, :W₃]), + ) ) - dataset = ( - T₁ = categorical(["a", "a", "b", "b", "c", "b", "b"]), - T₂ = categorical([3, 2, 1, 1, 2, 2, 2], ordered=true), - T₃ = categorical([true, false, true, false, false, false, false], ordered=true), - Y = [1., 2., 3, 4, 5, 6, 7], - W = rand(7), + η̂ = TMLE.CMRelevantFactorsEstimator( + nothing, + ( + Y = with_encoder(ConstantClassifier()), + T₁ = ConstantClassifier(), + T₂ = ConstantClassifier() + ) ) - ps_lowerbound = 1e-8 - weighted_fluctuation = false + η̂ₙ = η̂(η, dataset, verbosity = 0) + X = dataset[!, collect(η̂ₙ.outcome_mean.estimand.parents)] + y = dataset[!, η̂ₙ.outcome_mean.estimand.outcome] - initial_factors = fit!(Ψ, dataset, verbosity=0) - targeted_factors = TMLE.fluctuate(initial_factors, Ψ; - ps_lowerbound=ps_lowerbound, - weighted_fluctuation=weighted_fluctuation, - verbosity=1, - tol=nothing + fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; + tol=nothing, + ps_lowerbound=1e-8, + weighted=false, ) - @test targeted_factors.T₁ === initial_factors.T₁ - @test targeted_factors.T₂ === initial_factors.T₂ - @test targeted_factors.T₃ === initial_factors.T₃ - @test targeted_factors.Y !== initial_factors.Y - @test targeted_factors.Y.machine.cache.weighted_covariate ≈ - [0.0, 8.58, -21.44, 8.58, 0.0, -4.29, -4.29] atol=0.01 + fitresult, cache, report = MLJBase.fit(fluctuation, 0, X, y) + @test cache.weighted_covariate ≈ [2.45, -3.27, -3.27, 2.45, 2.45, -6.13, 8.17] atol=0.01 end @testset "Test fluctuation_input" begin diff --git a/test/helper_fns.jl b/test/helper_fns.jl index 741940a2..c9a2a6bf 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -20,11 +20,12 @@ The fluctuation is supposed to decrease the risk as its objective function is th It seems that sometimes this is not entirely true in practice, so the test actually checks that it does not increase risk more than tol """ -function test_fluct_decreases_risk(Ψ::TMLE.CMCompositeEstimand, fluctuation_mach; atol=1e-6) - outcome_mach = TMLE.get_outcome_model(Ψ) - y = outcome_mach.data[2] - initial_risk = risk(MLJBase.predict(outcome_mach), y) - fluct_risk = risk(MLJBase.predict(fluctuation_mach), y) +function test_fluct_decreases_risk(cache; atol=1e-6) + fluctuated_mean_machine = cache[:last_fluctuation].outcome_mean.machine + initial_mean_machine = fluctuated_mean_machine.model.initial_factors.outcome_mean.machine + y = fluctuated_mean_machine.data[2] + initial_risk = risk(MLJBase.predict(initial_mean_machine), y) + fluct_risk = risk(MLJBase.predict(fluctuated_mean_machine), y) @test initial_risk >= fluct_risk || isapprox(initial_risk, fluct_risk, atol=atol) end From 9ce592dd1a84e9e68db17a1ab124e074487674aa Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 21:28:35 +0100 Subject: [PATCH 074/151] fix dbl robust ate --- .../double_robustness_ate.jl | 230 +++++++++++------- 1 file changed, 142 insertions(+), 88 deletions(-) diff --git a/test/counterfactual_mean_based/double_robustness_ate.jl b/test/counterfactual_mean_based/double_robustness_ate.jl index dc0fe023..6a2eacc2 100644 --- a/test/counterfactual_mean_based/double_robustness_ate.jl +++ b/test/counterfactual_mean_based/double_robustness_ate.jl @@ -37,10 +37,8 @@ function binary_outcome_binary_treatment_pb(;n=100) ATE₁ = py_given_aw(1, 1)*p_w() + (1-p_w())*py_given_aw(1, 0) ATE₀ = py_given_aw(0, 1)*p_w() + (1-p_w())*py_given_aw(0, 0) ATE = ATE₁ - ATE₀ - # Define the SCM - scm = StaticConfoundedModel(:Y, :T, :W) - - return dataset, scm, ATE + + return dataset, ATE end """ @@ -59,10 +57,7 @@ function continuous_outcome_binary_treatment_pb(;n=100) dataset = (T = T, W₁ = W₁, W₂ = W₂, W₃ = W₃, Y = y) # Theroretical ATE ATE = 4 - # Define the SCM - scm = StaticConfoundedModel(:Y, :T, [:W₁, :W₂, :W₃]) - - return dataset, scm, ATE + return dataset, ATE end function continuous_outcome_categorical_treatment_pb(;n=100, control="TT", case="AA") @@ -79,9 +74,7 @@ function continuous_outcome_categorical_treatment_pb(;n=100, control="TT", case= dataset = (T = categorical(T), W₁ = W₁, W₂ = W₂, W₃ = W₃, Y = y) # True ATE: Ew[E[Y|t,w]] = ∑ᵤ (ft(T) + fw(w))p(w) = ft(t) + 0.5 ATE = (ft(case) + 0.5) - (ft(control) + 0.5) - # Define the SCM - scm = StaticConfoundedModel(:Y, :T, [:W₁, :W₂, :W₃]) - return dataset, scm, ATE + return dataset, ATE end @@ -107,168 +100,229 @@ function dataset_2_treatments_pb(;rng = StableRNG(123), n=100) control = zeros(n) ATE₁₁₋₀₁ = mean(μY(W₁, W₂, case, case) .- μY(W₁, W₂, control, case)) ATE₁₁₋₀₀ = mean(μY(W₁, W₂, case, case) .- μY(W₁, W₂, control, control)) - # Define the SCM - scm = StaticConfoundedModel([:Y], [:T₁, :T₂], [:W₁, :W₂]) - return dataset, scm , (ATE₁₁₋₀₁, ATE₁₁₋₀₀) + return dataset, (ATE₁₁₋₀₁, ATE₁₁₋₀₀) end @testset "Test Double Robustness ATE on continuous_outcome_categorical_treatment_pb" begin - dataset, scm, Ψ₀ = continuous_outcome_categorical_treatment_pb(;n=10_000, control="TT", case="AA") + dataset, Ψ₀ = continuous_outcome_categorical_treatment_pb(;n=10_000, control="TT", case="AA") Ψ = ATE( - scm, outcome = :Y, - treatment = (T=(case="AA", control="TT"),), - ) + treatment_values = (T=(case="AA", control="TT"),), + treatment_confounders = (T = [:W₁, :W₂, :W₃],) + ) + # When Q is misspecified but G is well specified - scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() - scm.T.model = LogisticClassifier(lambda=0) - - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, Q⁰ = ose!(Ψ, dataset, verbosity=0) + models = ( + Y = with_encoder(MLJModels.DeterministicConstantRegressor()), + T = LogisticClassifier(lambda=0) + ) + tmle = TMLEE(models) + ose = OSE(models) + naive = NAIVE(models.Y) + cache = Dict() + tmle_result, cache = tmle(Ψ, dataset, cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate!(Ψ, dataset, verbosity=0) == 0 + naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) + @test naive_result == 0 # When Q is well specified but G is misspecified - scm.Y.model = TreatmentTransformer() |> LinearRegressor() - scm.T.model = ConstantClassifier() + models = ( + Y = with_encoder(TreatmentTransformer() |> LinearRegressor()), + T = ConstantClassifier() + ) + tmle.models = models + ose.models = models + naive.model = models.Y - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @testset "Test Double Robustness ATE on binary_outcome_binary_treatment_pb" begin - dataset, scm, Ψ₀ = binary_outcome_binary_treatment_pb(;n=10_000) + dataset, Ψ₀ = binary_outcome_binary_treatment_pb(;n=10_000) Ψ = ATE( - scm, outcome = :Y, - treatment = (T=(case=true, control=false),), + treatment_values = (T=(case=true, control=false),), + treatment_confounders = (T=[:W],) ) # When Q is misspecified but G is well specified - scm.Y.model = TreatmentTransformer() |> ConstantClassifier() - scm.T.model = LogisticClassifier(lambda=0) + models = ( + Y = with_encoder(ConstantClassifier()), + T = LogisticClassifier(lambda=0) + ) + tmle = TMLEE(models) + ose = OSE(models) + naive = NAIVE(models.Y) + cache = Dict() - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate!(Ψ, dataset, verbosity=0) == 0 - + naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) + @test naive_result == 0 # When Q is well specified but G is misspecified - scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) - scm.T.model = ConstantClassifier() + models = ( + Y = with_encoder(LogisticClassifier(lambda=0)), + T = ConstantClassifier() + ) + tmle.models = models + ose.models = models + naive.model = models.Y - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) end @testset "Test Double Robustness ATE on continuous_outcome_binary_treatment_pb" begin - dataset, scm, Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) + dataset, Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) Ψ = ATE( - scm, - outcome = :Y, - treatment = (T=(case=true, control=false),), + outcome = :Y, + treatment_values = (T=(case=true, control=false),), + treatment_confounders = (T=[:W₁, :W₂, :W₃],) ) # When Q is misspecified but G is well specified - scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() - scm.T.model = LogisticClassifier(lambda=0) + models = ( + Y = with_encoder(MLJModels.DeterministicConstantRegressor()), + T = LogisticClassifier(lambda=0) + ) + tmle = TMLEE(models) + ose = OSE(models) + naive = NAIVE(models.Y) + cache = Dict() - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate(Ψ) == 0 + naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) + @test naive_result == 0 # When Q is well specified but G is misspecified - scm.Y.model = TreatmentTransformer() |> LinearRegressor() - scm.T.model = ConstantClassifier() + models = ( + Y = with_encoder(LinearRegressor()), + T = ConstantClassifier() + ) + tmle.models = models + ose.models = models + naive.model = models.Y - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @testset "Test Double Robustness ATE with two treatment variables" begin - dataset, scm, (ATE₁₁₋₀₁, ATE₁₁₋₀₀) = dataset_2_treatments_pb(;rng = StableRNG(123), n=50_000) + dataset, (ATE₁₁₋₀₁, ATE₁₁₋₀₀) = dataset_2_treatments_pb(;rng = StableRNG(123), n=50_000) # Test first ATE, only T₁ treatment varies Ψ = ATE( - scm, - outcome = :Y, - treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=1.)), + outcome = :Y, + treatment_values = ( + T₁=(case=1., control=0.), + T₂=(case=1., control=1.) + ), + treatment_confounders = ( + T₁ = [:W₁, :W₂], + T₂ = [:W₁, :W₂], + ) ) # When Q is misspecified but G is well specified - scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() - scm.T₁.model = LogisticClassifier(lambda=0) - scm.T₂.model = LogisticClassifier(lambda=0) + models = ( + Y = with_encoder(MLJModels.DeterministicConstantRegressor()), + T₁ = LogisticClassifier(lambda=0), + T₂ = LogisticClassifier(lambda=0) + ) + tmle = TMLEE(models) + ose = OSE(models) + naive = NAIVE(models.Y) + cache = Dict() - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₁) test_coverage(ose_result, ATE₁₁₋₀₁) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # When Q is well specified but G is misspecified - scm.Y.model = TreatmentTransformer() |> LinearRegressor() - scm.T₁.model = ConstantClassifier() - scm.T₂.model = ConstantClassifier() + models = ( + Y = with_encoder(LinearRegressor()), + T₁ = ConstantClassifier(), + T₂ = ConstantClassifier() + ) + tmle.models = models + ose.models = models + naive.model = models.Y - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₁) test_coverage(ose_result, ATE₁₁₋₀₁) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) # Test second ATE, two treatment varies Ψ = ATE( - scm, - outcome = :Y, - treatment = (T₁=(case=1., control=0.), T₂=(case=1., control=0.)), + outcome = :Y, + treatment_values = ( + T₁=(case=1., control=0.), + T₂=(case=1., control=0.) + ), + treatment_confounders = ( + T₁ = [:W₁, :W₂], + T₂ = [:W₁, :W₂], + ) ) - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₀) test_coverage(ose_result, ATE₁₁₋₀₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # When Q is well specified but G is misspecified - scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() - scm.T₁.model = LogisticClassifier(lambda=0) - scm.T₂.model = LogisticClassifier(lambda=0) + models = ( + Y = with_encoder(MLJModels.DeterministicConstantRegressor()), + T₁ = LogisticClassifier(lambda=0), + T₂ = LogisticClassifier(lambda=0), + ) + tmle.models = models + ose.models = models + naive.model = models.Y - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) test_coverage(tmle_result, ATE₁₁₋₀₀) test_coverage(ose_result, ATE₁₁₋₀₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end From e57ff4878fd57a42403acd25d8715dcc911c15a6 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 21:44:06 +0100 Subject: [PATCH 075/151] fix db robust iate --- src/counterfactual_mean_based/estimators.jl | 32 +++- .../double_robustness_iate.jl | 174 +++++++++++------- 2 files changed, 132 insertions(+), 74 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index df0f45bb..18f095ad 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -7,6 +7,9 @@ struct CMRelevantFactorsEstimator <: Estimator models end +key(estimator::CMRelevantFactorsEstimator) = + (CMRelevantFactorsEstimator, estimator.resampling, estimator.models) + get_train_validation_indices(resampling::Nothing, factors, dataset) = nothing function get_train_validation_indices(resampling::ResamplingStrategy, factors, dataset) @@ -22,11 +25,13 @@ function get_train_validation_indices(resampling::ResamplingStrategy, factors, d end function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1) - # Lookup in cache - if haskey(cache, (estimand, estimator)) - return from_cache(cache, estimand, estimator; verbosity=verbosity) + cache_key = key(estimand, estimator) + if haskey(cache, cache_key) + estimate = cache[cache_key] + verbosity > 0 && @info(string("Reusing estimate for: ", string_repr(estimand))) + return estimate end - # Otherwise estimate + verbosity > 0 && @info(string("Estimating: ", string_repr(estimand))) models = estimator.models # Get train validation indices train_validation_indices = TMLE.get_train_validation_indices(estimator.resampling, estimand, dataset) @@ -145,7 +150,7 @@ end OSE(models; resampling=nothing, ps_lowerbound=1e-8) = OSE(models, resampling, ps_lowerbound) -function (estimator::OSE)(Ψ::CMCompositeEstimand, dataset; verbosity=1) +function (estimator::OSE)(Ψ::CMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors @@ -154,7 +159,12 @@ function (estimator::OSE)(Ψ::CMCompositeEstimand, dataset; verbosity=1) nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(initial_factors)) initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, estimator.resampling) initial_factors_estimator = TMLE.CMRelevantFactorsEstimator(estimator.resampling, estimator.models) - initial_factors_estimate = initial_factors_estimator(initial_factors, initial_factors_dataset, verbosity=verbosity) + initial_factors_estimate = initial_factors_estimator( + initial_factors, + initial_factors_dataset; + cache=cache, + verbosity=verbosity + ) # Get propensity score truncation threshold n = nrows(nomissing_dataset) ps_lowerbound = TMLE.ps_lower_bound(n, estimator.ps_lowerbound) @@ -162,7 +172,7 @@ function (estimator::OSE)(Ψ::CMCompositeEstimand, dataset; verbosity=1) # Gradient and estimate IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." - return OSEstimate(Ψ̂ + mean(IC), IC), initial_factors_estimate + return OSEstimate(Ψ̂ + mean(IC), IC), cache end ##################################################################### @@ -173,7 +183,7 @@ mutable struct NAIVE <: Estimator model end -function (estimator::NAIVE)(Ψ::CMCompositeEstimand, dataset; verbosity=1) +function (estimator::NAIVE)(Ψ::CMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors @@ -182,8 +192,10 @@ function (estimator::NAIVE)(Ψ::CMCompositeEstimand, dataset; verbosity=1) nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) outcome_mean_estimate = MLConditionalDistributionEstimator(estimator.model)( relevant_factors.outcome_mean, - dataset; + dataset; + cache=cache, verbosity=verbosity ) - return mean(counterfactual_aggregate(Ψ, outcome_mean_estimate, nomissing_dataset)) + Ψ̂ = mean(counterfactual_aggregate(Ψ, outcome_mean_estimate, nomissing_dataset)) + return Ψ̂, cache end \ No newline at end of file diff --git a/test/counterfactual_mean_based/double_robustness_iate.jl b/test/counterfactual_mean_based/double_robustness_iate.jl index d94a3a8e..f91f2874 100644 --- a/test/counterfactual_mean_based/double_robustness_iate.jl +++ b/test/counterfactual_mean_based/double_robustness_iate.jl @@ -59,10 +59,7 @@ function binary_outcome_binary_treatment_pb(;n=100) temp -= μy_fn(w, [0], [1])[1] IATE += temp*0.5*0.5*0.5 end - # Define the SCM - scm = StaticConfoundedModel([:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃]) - - return dataset, scm, IATE + return dataset, IATE end @@ -120,10 +117,7 @@ function binary_outcome_categorical_treatment_pb(;n=100) temp -= μy_fn(w, (T₁=categorical(["CG"], levels=levels₁), T₂=categorical(["AT"], levels=levels₂)), Hmach)[1] IATE += temp*0.5*0.5*0.5 end - # Define the SCM - scm = StaticConfoundedModel([:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃]) - - return dataset, scm, IATE + return dataset, IATE end @@ -168,121 +162,173 @@ function continuous_outcome_binary_treatment_pb(;n=100) temp -= μy_fn(w, [0], [1])[1] IATE += temp*0.5*0.5*0.5 end - - # Define the SCM - scm = StaticConfoundedModel([:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃]) - - return dataset, scm, IATE + return dataset, IATE end @testset "Test Double Robustness IATE on binary_outcome_binary_treatment_pb" begin - dataset, scm, Ψ₀ = binary_outcome_binary_treatment_pb(n=10_000) + dataset, Ψ₀ = binary_outcome_binary_treatment_pb(n=10_000) Ψ = IATE( - scm, outcome=:Y, - treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), + treatment_values = ( + T₁=(case=true, control=false), + T₂=(case=true, control=false) + ), + treatment_confounders = ( + T₁=[:W₁, :W₂, :W₃], + T₂=[:W₁, :W₂, :W₃], + ) ) # When Q is misspecified but G is well specified - scm.Y.model = TreatmentTransformer() |> ConstantClassifier() - scm.T₁.model = LogisticClassifier(lambda=0) - scm.T₂.model = LogisticClassifier(lambda=0) + models = ( + Y = with_encoder(ConstantClassifier()), + T₁ = LogisticClassifier(lambda=0), + T₂ = LogisticClassifier(lambda=0), + ) + tmle = TMLEE(models) + ose = OSE(models) + naive = NAIVE(models.Y) + cache = Dict() - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0); + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate(Ψ) == 0 + naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) + @test naive_result == 0 # When Q is well specified but G is misspecified - scm.Y.model = TreatmentTransformer() |> LogisticClassifier(lambda=0) - scm.T₁.model = ConstantClassifier() - scm.T₂.model = ConstantClassifier() + models = ( + Y = with_encoder(LogisticClassifier(lambda=0)), + T₁ = ConstantClassifier(), + T₂ = ConstantClassifier(), + ) + tmle.models = models + ose.models = models + naive.model = models.Y - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0); + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate(Ψ) ≈ -0.0 atol=1e-1 + naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) + @test naive_result ≈ -0.0 atol=1e-1 end @testset "Test Double Robustness IATE on continuous_outcome_binary_treatment_pb" begin - dataset, scm,Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) + dataset, Ψ₀ = continuous_outcome_binary_treatment_pb(n=10_000) Ψ = IATE( - scm, - outcome=:Y, - treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), + outcome = :Y, + treatment_values = ( + T₁=(case=true, control=false), + T₂=(case=true, control=false) + ), + treatment_confounders = ( + T₁=[:W₁, :W₂, :W₃], + T₂=[:W₁, :W₂, :W₃], + ) ) # When Q is misspecified but G is well specified - scm.Y.model = TreatmentTransformer() |> MLJModels.DeterministicConstantRegressor() - scm.T₁.model = LogisticClassifier(lambda=0) - scm.T₂.model = LogisticClassifier(lambda=0) + models = ( + Y = with_encoder(MLJModels.DeterministicConstantRegressor()), + T₁ = LogisticClassifier(lambda=0), + T₂ = LogisticClassifier(lambda=0), + ) + tmle = TMLEE(models) + ose = OSE(models) + naive = NAIVE(models.Y) + cache = Dict() - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate(Ψ) == 0 + naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) + @test naive_result == 0 # When Q is well specified but G is misspecified - scm.Y.model = TreatmentTransformer() |> cont_interacter - scm.T₁.model = ConstantClassifier() - scm.T₂.model = ConstantClassifier() + models = ( + Y = with_encoder(cont_interacter), + T₁ = ConstantClassifier(), + T₂ = ConstantClassifier(), + ) + tmle.models = models + ose.models = models + naive.model = models.Y - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0) - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) end @testset "Test Double Robustness IATE on binary_outcome_categorical_treatment_pb" begin - dataset, scm, Ψ₀ = binary_outcome_categorical_treatment_pb(n=30_000) + dataset, Ψ₀ = binary_outcome_categorical_treatment_pb(n=30_000) Ψ = IATE( - scm, outcome=:Y, - treatment=(T₁=(case="CC", control="CG"), T₂=(case="AT", control="AA")), + treatment_values= ( + T₁=(case="CC", control="CG"), + T₂=(case="AT", control="AA") + ), + treatment_confounders = ( + T₁=[:W₁, :W₂, :W₃], + T₂=[:W₁, :W₂, :W₃], + ) ) # When Q is misspecified but G is well specified - scm.Y.model = TreatmentTransformer() |> ConstantClassifier() - scm.T₁.model = LogisticClassifier(lambda=0) - scm.T₂.model = LogisticClassifier(lambda=0) + models = ( + Y = with_encoder(ConstantClassifier()), + T₁ = LogisticClassifier(lambda=0), + T₂ = LogisticClassifier(lambda=0) + ) + tmle = TMLEE(models) + ose = OSE(models) + naive = NAIVE(models.Y) + cache = Dict() - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0); + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate(Ψ) == 0 + naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) + @test naive_result == 0 # When Q is well specified but G is misspecified - scm.Y.model = TreatmentTransformer() |> cat_interacter - scm.T₁.model = ConstantClassifier() - scm.T₂.model = ConstantClassifier() + models = ( + Y = with_encoder(cat_interacter), + T₁ = ConstantClassifier(), + T₂ = ConstantClassifier(), + ) + tmle.models = models + ose.models = models + naive.model = models.Y - tmle_result, fluctuation_mach = tmle!(Ψ, dataset, verbosity=0); - ose_result, fluctuation_mach = ose!(Ψ, dataset, verbosity=0); + tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0); + ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); test_coverage(tmle_result, Ψ₀) test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) # The initial estimate is far away - @test naive_plugin_estimate(Ψ) ≈ -0.02 atol=1e-2 + naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) + @test naive_result ≈ -0.02 atol=1e-2 end From 8331876b5d207ca2e3387a9f5ca452387a3d3eda Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 21:47:54 +0100 Subject: [PATCH 076/151] fix non regression --- .../non_regression_test.jl | 21 +++---------------- 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index be4a69fd..e67ad5c9 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -31,7 +31,7 @@ using MLJBase resampling=nothing # No CV ps_lowerbound = 0.025 # Cutoff hardcoded in tmle3 weighted = false # Unweighted fluctuation - verbosity = 1 # No logs + verbosity = 0 # No logs tmle = TMLEE(models; resampling=resampling, ps_lowerbound=ps_lowerbound, @@ -39,9 +39,6 @@ using MLJBase ) tmle_result, cache = tmle(Ψ, dataset; verbosity=verbosity); - tmle_result - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=verbosity); - cache tmle.models = ( haz01 = with_encoder(LinearBinaryClassifier()), parity01 = LinearBinaryClassifier(fit_intercept=false) @@ -52,22 +49,10 @@ using MLJBase @test u ≈ -0.091821 atol = 1e-6 @test OneSampleZTest(tmle_result) isa OneSampleZTest - # OSE - ose = OSE(models; - resampling=resampling, - ps_lowerbound=ps_lowerbound - ) - ose_result, targeted_factors = ose(Ψ, dataset; verbosity=verbosity) - ose_result - - # CV-TMLE - tmle.resampling = StratifiedCV(nfolds=10) - cv_tmle_result, targeted_factors = tmle(Ψ, dataset; verbosity=verbosity) - cv_tmle_result - # Naive naive = NAIVE(models.haz01) - @test naive(Ψ, dataset) ≈ -0.150078 atol = 1e-6 + naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=verbosity) + @test naive_result ≈ -0.150078 atol = 1e-6 end From dceac151d2f2c0856f9f10c9416b88a8400332d0 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 23 Sep 2023 21:58:33 +0100 Subject: [PATCH 077/151] fix more stuff --- test/composition.jl | 60 ++++++++++++------- .../estimators_and_estimates.jl | 10 +--- test/missing_management.jl | 4 +- 3 files changed, 41 insertions(+), 33 deletions(-) diff --git a/test/composition.jl b/test/composition.jl index a78e2c50..31b4dd92 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -8,7 +8,7 @@ using MLJLinearModels using TMLE using CategoricalArrays -function make_dataset_and_scm(;n=100) +function make_dataset(;n=100) rng = StableRNG(123) W = rand(rng, Uniform(), n) T = rand(rng, [0, 1], n) @@ -18,8 +18,7 @@ function make_dataset_and_scm(;n=100) W = W, T = categorical(T) ) - scm = StaticConfoundedModel(:Y, :T, :W) - return dataset, scm + return dataset end @testset "Test cov" begin @@ -33,23 +32,31 @@ end end @testset "Test composition CM(1) - CM(0) = ATE(1,0)" begin - dataset, scm = make_dataset_and_scm(;n=1000) + dataset = make_dataset(;n=1000) # Counterfactual Mean T = 1 CM₁ = CM( - scm, outcome = :Y, - treatment = (T=1,) + treatment_values = (T=1,), + treatment_confounders = (T=[:W],) ) - CM_tmle_result₁, _ = tmle!(CM₁, dataset, verbosity=0) - CM_ose_result₁, _ = ose!(CM₁, dataset, verbosity=0) + models = ( + Y = with_encoder(LinearRegressor()), + T = LogisticClassifier(lambda=0) + ) + tmle = TMLEE(models) + ose = OSE(models) + cache = Dict() + + CM_tmle_result₁, cache = tmle(CM₁, dataset; cache=cache, verbosity=0) + CM_ose_result₁, cache = ose(CM₁, dataset; cache=cache, verbosity=0) # Counterfactual Mean T = 0 CM₀ = CM( - scm, outcome = :Y, - treatment = (T=0,) + treatment_values = (T=0,), + treatment_confounders = (T=[:W],) ) - CM_tmle_result₀, _ = tmle!(CM₀, dataset, verbosity=0) - CM_ose_result₀, _ = ose!(CM₀, dataset, verbosity=0) + CM_tmle_result₀, cache = tmle(CM₀, dataset; cache=cache, verbosity=0) + CM_ose_result₀, cache = ose(CM₀, dataset; cache=cache, verbosity=0) # Composition of TMLE CM_result_composed_tmle = compose(-, CM_tmle_result₁, CM_tmle_result₀); @@ -57,12 +64,12 @@ end # Via ATE ATE₁₀ = ATE( - scm, outcome = :Y, - treatment = (T=(case=1, control=0),), + treatment_values = (T=(case=1, control=0),), + treatment_confounders = (T=[:W],) ) # Check composed TMLE - ATE_tmle_result₁₀, _ = tmle!(ATE₁₀, dataset, verbosity=0) + ATE_tmle_result₁₀, cache = tmle(ATE₁₀, dataset; cache=cache, verbosity=0) @test estimate(ATE_tmle_result₁₀) ≈ estimate(CM_result_composed_tmle) atol = 1e-7 # T Test composed_confint = collect(confint(OneSampleTTest(CM_result_composed_tmle))) @@ -76,7 +83,7 @@ end @test var(ATE_tmle_result₁₀) ≈ var(CM_result_composed_tmle) atol = 1e-3 # Check composed OSE - ATE_ose_result₁₀, _ = ose!(ATE₁₀, dataset, verbosity=0) + ATE_ose_result₁₀, cache = ose(ATE₁₀, dataset; cache=cache, verbosity=0) @test estimate(ATE_ose_result₁₀) ≈ estimate(CM_result_composed_ose) atol = 1e-7 # T Test composed_confint = collect(confint(OneSampleTTest(CM_result_composed_ose))) @@ -91,20 +98,27 @@ end end @testset "Test compose multidimensional function" begin - dataset, scm = make_dataset_and_scm(;n=1000) + dataset = make_dataset(;n=1000) + models = ( + Y = with_encoder(LinearRegressor()), + T = LogisticClassifier(lambda=0) + ) + tmle = TMLEE(models) + cache = Dict() + CM₁ = CM( - scm, outcome = :Y, - treatment = (T=1,) + treatment_values = (T=1,), + treatment_confounders = (T=[:W],) ) - CM_result₁, _ = tmle!(CM₁, dataset, verbosity=0) + CM_result₁, cache = tmle(CM₁, dataset; cache=cache, verbosity=0) CM₀ = CM( - scm, outcome = :Y, - treatment = (T=0,) + treatment_values = (T=0,), + treatment_confounders = (T=[:W],) ) - CM_result₀, _ = tmle!(CM₀, dataset, verbosity=0) + CM_result₀, cache = tmle(CM₀, dataset; cache=cache, verbosity=0) f(x, y) = [x^2 - y, x/y, 2x + 3y] CM_result_composed = compose(f, CM_result₁, CM_result₀) diff --git a/test/counterfactual_mean_based/estimators_and_estimates.jl b/test/counterfactual_mean_based/estimators_and_estimates.jl index c93387df..cc9ac9a2 100644 --- a/test/counterfactual_mean_based/estimators_and_estimates.jl +++ b/test/counterfactual_mean_based/estimators_and_estimates.jl @@ -44,14 +44,8 @@ function build_dataset_and_scm(;n=100) Y₂ = Y₂, Y₃ = Y₃ ) - scm = SCM( - SE(:T₁, [:W₁, :W₂], LogisticClassifier(lambda=0)), - SE(:T₂, [:W₁, :W₂], LogisticClassifier(lambda=0)), - SE(:Y₁, [:T₁, :T₂, :W₁, :W₂, :C₁], TreatmentTransformer() |> LinearRegressor()), - SE(:Y₂, [:T₂, :W₂], TreatmentTransformer() |> LinearRegressor()), - SE(:Y₃, [:T₁, :T₂, :W₁, :W₂, :C₁], TreatmentTransformer() |> LinearRegressor()), - ) - return dataset, scm + + return dataset end diff --git a/test/missing_management.jl b/test/missing_management.jl index 6125deb8..7351d2f4 100644 --- a/test/missing_management.jl +++ b/test/missing_management.jl @@ -47,9 +47,9 @@ end treatment_confounders=(T=[:W],)) models=(Y=with_encoder(LinearRegressor()), T=LogisticClassifier(lambda=0)) tmle = TMLEE(models) - tmle_result, fluctuation_mach = tmle(Ψ, dataset; verbosity=0) + tmle_result, cache = tmle(Ψ, dataset; verbosity=0) test_coverage(tmle_result, 1) - test_fluct_decreases_risk(Ψ, fluctuation_mach) + test_fluct_decreases_risk(cache) end end From 2e99ee70630230a9008f7d3eae5db0213202bb74 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 24 Sep 2023 11:36:08 +0100 Subject: [PATCH 078/151] on my way to fixed tests --- src/counterfactual_mean_based/estimators.jl | 18 +- src/estimators.jl | 26 +- src/utils.jl | 3 + .../estimators_and_estimates.jl | 353 +++--------------- test/estimators_and_estimates.jl | 4 - 5 files changed, 76 insertions(+), 328 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 18f095ad..4da0d84b 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -7,6 +7,9 @@ struct CMRelevantFactorsEstimator <: Estimator models end +CMRelevantFactorsEstimator(;models, resampling=nothing) = + CMRelevantFactorsEstimator(resampling, models) + key(estimator::CMRelevantFactorsEstimator) = (CMRelevantFactorsEstimator, estimator.resampling, estimator.models) @@ -25,13 +28,14 @@ function get_train_validation_indices(resampling::ResamplingStrategy, factors, d end function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1) - cache_key = key(estimand, estimator) - if haskey(cache, cache_key) - estimate = cache[cache_key] - verbosity > 0 && @info(string("Reusing estimate for: ", string_repr(estimand))) - return estimate + if haskey(cache, estimand) + old_estimator, estimate = cache[estimand] + if key(old_estimator) == key(estimator) + verbosity > 0 && @info(reuse_string(estimand)) + return estimate + end end - verbosity > 0 && @info(string("Estimating: ", string_repr(estimand))) + verbosity > 0 && @info(fit_string(estimand)) models = estimator.models # Get train validation indices train_validation_indices = TMLE.get_train_validation_indices(estimator.resampling, estimand, dataset) @@ -55,7 +59,7 @@ function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict() # Build estimate estimate = MLCMRelevantFactors(estimand, outcome_mean_estimate, propensity_score_estimate) # Update cache - cache[(estimand, estimator)] = estimate + cache[estimand] = estimator => estimate return estimate end diff --git a/src/estimators.jl b/src/estimators.jl index 6857a1e0..5ac87a28 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -10,11 +10,12 @@ end function (estimator::MLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1) # Lookup in cache - cache_key = key(estimand, estimator) - if haskey(cache, cache_key) - estimate = cache[cache_key] - verbosity > 0 && @info(string("Reusing estimate for: ", string_repr(estimand))) - return estimate + if haskey(cache, estimand) + old_estimator, estimate = cache[estimand] + if key(old_estimator) == key(estimator) + verbosity > 0 && @info(reuse_string(estimand)) + return estimate + end end verbosity > 0 && @info(string("Estimating: ", string_repr(estimand))) # Otherwise estimate @@ -27,7 +28,7 @@ function (estimator::MLConditionalDistributionEstimator)(estimand, dataset; cach # Build estimate estimate = MLConditionalDistribution(estimand, mach) # Update cache - cache[cache_key] = estimate + cache[estimand] = estimator => estimate return estimate end @@ -46,11 +47,12 @@ end function (estimator::SampleSplitMLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1) # Lookup in cache - cache_key = key(estimand, estimator) - if haskey(cache, cache_key) - estimate = cache[cache_key] - verbosity > 0 && @info(string("Reusing estimate for: ", string_repr(estimand))) - return estimate + if haskey(cache, estimand) + old_estimator, estimate = cache[estimand] + if key(old_estimator) == key(estimator) + verbosity > 0 && @info(reuse_string(estimand)) + return estimate + end end # Otherwise estimate verbosity > 0 && @info(string("Estimating: ", string_repr(estimand))) @@ -70,7 +72,7 @@ function (estimator::SampleSplitMLConditionalDistributionEstimator)(estimand, da # Build estimate estimate = SampleSplitMLConditionalDistribution(estimand, estimator.train_validation_indices, machines) # Update cache - cache[cache_key] = estimate + cache[estimand] = estimator => estimate return estimate end diff --git a/src/utils.jl b/src/utils.jl index d7b05a56..c03a8422 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,9 @@ ## General Utilities ############################################################################### +reuse_string(estimand) = string("Reusing estimate for: ", string_repr(estimand)) +fit_string(estimand) = string("Estimating: ", string_repr(estimand)) + key(estimand, estimator) = (key(estimand), key(estimator)) unique_sorted_tuple(iter) = Tuple(sort(unique(Symbol(x) for x in iter))) diff --git a/test/counterfactual_mean_based/estimators_and_estimates.jl b/test/counterfactual_mean_based/estimators_and_estimates.jl index cc9ac9a2..0bd7085d 100644 --- a/test/counterfactual_mean_based/estimators_and_estimates.jl +++ b/test/counterfactual_mean_based/estimators_and_estimates.jl @@ -12,319 +12,62 @@ using LogExpFunctions include(joinpath(dirname(@__DIR__), "helper_fns.jl")) -""" -Results derived by hand for this dataset: - -- ATE(y₁, T₁, W₁, W₂, C₁) = 1 - 4E[C₁] = 1 - 2 = -1 -- ATE(y₁, T₂, W₁, W₂, C₁) = E[W₂] = 0.5 -- ATE(y₁, (T₁, T₂), W₁, W₂, C₁) = E[-4C₁ + 1 + W₂] = -0.5 -- ATE(y₂, T₂, W₁, W₂, C₁) = 10 -""" -function build_dataset_and_scm(;n=100) - rng = StableRNG(123) - # Confounders - W₁ = rand(rng, Uniform(), n) - W₂ = rand(rng, Uniform(), n) - # Covariates - C₁ = rand(rng, n) - # Treatment | Confounders - T₁ = rand(rng, Uniform(), n) .< logistic.(0.5sin.(W₁) .- 1.5W₂) - T₂ = rand(rng, Uniform(), n) .< logistic.(-3W₁ - 1.5W₂) - # outcome | Confounders, Covariates, Treatments - Y₁ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .+ T₁ + T₂.*W₂ .+ rand(rng, Normal(0, 0.1), n) - Y₂ = 1 .+ 10T₂ .+ W₂ .+ rand(rng, Normal(0, 0.1), n) - Y₃ = 1 .+ 2W₁ .+ 3W₂ .- 4C₁.*T₁ .- 2T₂.*T₁.*W₂ .+ rand(rng, Normal(0, 0.1), n) - dataset = ( - T₁ = categorical(T₁), - T₂ = categorical(T₂), - W₁ = W₁, - W₂ = W₂, - C₁ = C₁, - Y₁ = Y₁, - Y₂ = Y₂, - Y₃ = Y₃ - ) - - return dataset -end - - -table_types = (Tables.columntable, DataFrame) - -@testset "Test MLCMRelevantFactors" begin +@testset "Test CMRelevantFactorsEstimator" begin n = 100 W = rand(Normal(), n) T₁ = rand(n) .< logistic.(1 .- W) Y = T₁ .+ W .+ rand(n) dataset = DataFrame(Y=Y, W=W, T₁=categorical(T₁, ordered=true)) - models = (Y=with_encoder(LinearRegressor()), T₁=LinearBinaryClassifier()) - estimand = TMLE.CMRelevantFactors( - scm, - outcome_mean = TMLE.ConditionalDistribution(scm, :Y, Set([:T₁, :W])), - propensity_score = (TMLE.ConditionalDistribution(scm, :T₁, Set([:W])),) - ) - # No resampling - resampling = nothing - estimate = Distributions.estimate(estimand, resampling, models, dataset; - factors_cache=nothing, - verbosity=0 - ) - outcome_model_features = [:T₁, :W] - @test predict(estimate.outcome_mean, dataset) == predict(estimate.outcome_mean.machine, dataset[!, outcome_model_features]) - @test predict(estimate.propensity_score[1], dataset).prob_given_ref == predict(estimate.propensity_score[1].machine, dataset[!, [:W]]).prob_given_ref - - # No resampling - nfolds = 3 - resampling = CV(nfolds=nfolds) - estimate = Distributions.estimate(estimand, resampling, models, dataset; - factors_cache=nothing, - verbosity=0 - ) - train_validation_indices = estimate.outcome_mean.train_validation_indices - # Check all models share the same train/validation indices - @test train_validation_indices == estimate.propensity_score[1].train_validation_indices - outcome_model_features = [:T₁, :W] - ŷ = predict(estimate.outcome_mean, dataset) - t̂ = predict(estimate.propensity_score[1], dataset) - for foldid in 1:nfolds - train, val = train_validation_indices[foldid] - # The predictions on validation samples are made from - # the machine trained on the train sample - ŷfold = predict(estimate.outcome_mean.machines[foldid], dataset[val, outcome_model_features]) - @test ŷ[val] == ŷfold - t̂fold = predict(estimate.propensity_score[1].machines[foldid], dataset[val, [:W]]) - @test t̂fold.prob_given_ref == t̂[val].prob_given_ref - end + # Estimand + Q = TMLE.ConditionalDistribution(:Y, [:T₁, :W]) + G = (TMLE.ConditionalDistribution(:T₁, [:W]),) + η = TMLE.CMRelevantFactors(outcome_mean=Q, propensity_score=G) + # Estimator + models = ( + Y = with_encoder(LinearRegressor()), + T₁ = LogisticClassifier() + ) + η̂ = TMLE.CMRelevantFactorsEstimator(models=models) + # Estimate + fit_log = ( + (:info, TMLE.fit_string(η)), + (:info, TMLE.fit_string(G[1])), + (:info, TMLE.fit_string(Q)) + ) + cache = Dict() + η̂ₙ = @test_logs fit_log... η̂(η, dataset; cache=cache, verbosity=1) + # Test both sub estimands have been fitted + @test η̂ₙ.outcome_mean isa TMLE.MLConditionalDistribution + @test fitted_params(η̂ₙ.outcome_mean.machine) isa NamedTuple + @test η̂ₙ.propensity_score[1] isa TMLE.MLConditionalDistribution + @test fitted_params(η̂ₙ.propensity_score[1].machine) isa NamedTuple + + # Both models unchanged, η̂ₙ is fully reused + new_models = ( + Y = with_encoder(LinearRegressor()), + T₁ = LogisticClassifier() + ) + new_η̂ = TMLE.CMRelevantFactorsEstimator(models=new_models) + @test TMLE.key(η, new_η̂) == TMLE.key(η, η̂) + full_reuse_log = (:info, TMLE.reuse_string(η)) + @test_logs full_reuse_log new_η̂(η, dataset; cache=cache, verbosity=1) + + # Changing one model, only the other one is refitted + new_models = ( + Y = with_encoder(LinearRegressor()), + T₁ = LogisticClassifier(fit_intercept=false) + ) + new_η̂ = TMLE.CMRelevantFactorsEstimator(models=new_models) + @test TMLE.key(η, new_η̂) != TMLE.key(η, η̂) + partial_reuse_log = ( + (:info, TMLE.fit_string(η)), + (:info, TMLE.fit_string(G[1])), + (:info, TMLE.reuse_string(Q)) + ) + @test_logs partial_reuse_log... new_η̂(η, dataset; cache=cache, verbosity=1) end -@testset "Test Warm restart: ATE single treatment, $tt" for tt in table_types - dataset, scm = build_dataset_and_scm(;n=50_000) - dataset = tt(dataset) - # Define the estimand of interest - Ψ = ATE( - scm, - outcome=:Y₁, - treatment=(T₁=(case=true, control=false),), - ) - # Run TMLE - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable T₁."), - (:info, "Fitting Structural Equation corresponding to variable Y₁."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset; verbosity=1); - # The TMLE covers the ground truth but the initial estimate does not - Ψ₀ = -1 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Update the treatment specification - # Nuisance estimands should not fitted again - Ψ = ATE( - scm, - outcome=:Y₁, - treatment=(T₁=(case=false, control=true),), - ) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - Ψ₀ = 1 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # New marginal treatment estimation: T₂ - # This will trigger fit of the corresponding equations - Ψ = ATE( - scm, - outcome=:Y₁, - treatment=(T₂=(case=true, control=false),), - ) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable T₂."), - (:info, "Fitting Structural Equation corresponding to variable Y₁."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - Ψ₀ = 0.5 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Change the outcome: Y₂ - # This will trigger the fit for the corresponding equation - Ψ = ATE( - scm, - outcome=:Y₂, - treatment=(T₂=(case=true, control=false),), - ) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable Y₂."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - Ψ₀ = 10 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # New estimand with 2 treatment variables - Ψ = ATE( - scm, - outcome=:Y₁, - treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - ) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable Y₁."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - Ψ₀ = -0.5 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Switching the T₂ setting - Ψ = ATE( - scm, - outcome=:Y₁, - treatment=(T₁=(case=true, control=false), T₂=(case=false, control=true)), - ) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - Ψ₀ = -1.5 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) -end - -@testset "Test Warm restart: CM, $tt" for tt in table_types - dataset, scm = build_dataset_and_scm(;n=50_000) - dataset = tt(dataset) - - Ψ = CM( - scm, - outcome=:Y₁, - treatment=(T₁=true, T₂=true), - ) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable T₁."), - (:info, "Fitting Structural Equation corresponding to variable T₂."), - (:info, "Fitting Structural Equation corresponding to variable Y₁."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - Ψ₀ = 3 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Let's switch case and control for T₂ - Ψ = CM( - scm, - outcome=:Y₁, - treatment=(T₁=true, T₂=false), - ) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - Ψ₀ = 2.5 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Change the outcome - Ψ = CM( - scm, - outcome=:Y₂, - treatment=(T₂=false, ), - ) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable Y₂."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - Ψ₀ = 1.5 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Change the treatment model - scm.T₂.model = LogisticClassifier(lambda=0.01) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable T₂."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) -end - -@testset "Test Warm restart: pairwise IATE, $tt" for tt in table_types - dataset, scm = build_dataset_and_scm(;n=50_000) - dataset = tt(dataset) - Ψ = IATE( - scm, - outcome=:Y₃, - treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)), - ) - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable T₁."), - (:info, "Fitting Structural Equation corresponding to variable T₂."), - (:info, "Fitting Structural Equation corresponding to variable Y₃."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - Ψ₀ = -1 - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - - # Changing some equations - scm.Y₃.model = TreatmentTransformer() |> RidgeRegressor(lambda=1e-5) - scm.T₂.model = LogisticClassifier(lambda=1e-5) - - log_sequence = ( - (:info, "Fitting the required equations..."), - (:info, "Fitting Structural Equation corresponding to variable T₂."), - (:info, "Fitting Structural Equation corresponding to variable Y₃."), - (:info, "Performing TMLE..."), - (:info, "Done.") - ) - tmle_result, fluctuation_mach = @test_logs log_sequence... tmle!(Ψ, dataset, verbosity=1); - test_coverage(tmle_result, Ψ₀) - test_fluct_decreases_risk(Ψ, fluctuation_mach) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) -end end diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index 9244b33f..e09cc642 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -36,8 +36,6 @@ reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) # Changing the model leads to refit new_estimator = TMLE.MLConditionalDistributionEstimator(LinearRegressor(fit_intercept=false)) @test_logs (:info, fit_log) new_estimator(estimand, dataset; cache=cache, verbosity=verbosity) - # The cache contains two estimators for the estimand - @test length(cache) == 2 end @testset "Test SampleSplitMLConditionalDistributionEstimator" begin @@ -84,8 +82,6 @@ end train_validation_indices ) @test_logs (:info, fit_log) new_estimator(estimand, dataset; cache=cache, verbosity=verbosity) - # The cache contains 3 estimators for the estimand - @test length(cache) == 3 end @testset "Test compute_offset MLConditionalDistributionEstimator" begin From d43e37b09638ea477d6eefb02d3fde794ce050ec Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 24 Sep 2023 11:57:12 +0100 Subject: [PATCH 079/151] fixed tests --- test/estimators_and_estimates.jl | 37 +++++++++++++++++--------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index e09cc642..4fd0fe65 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -10,14 +10,14 @@ using LogExpFunctions verbosity = 1 n = 100 -X, y = make_regression(n, 4) -dataset = DataFrame(Y=y, T₁=X.x1, T₂=X.x2, W=X.x3) -estimand = ConditionalDistribution(:Y, [:W, :T₁, :T₂]) +X, y = make_moons(n) +dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) +estimand = ConditionalDistribution(:Y, [:X₁, :X₂]) fit_log = string("Estimating: ", TMLE.string_repr(estimand)) reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) @testset "Test MLConditionalDistributionEstimator" begin - estimator = TMLE.MLConditionalDistributionEstimator(LinearRegressor()) + estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier()) # Fitting with no cache cache = Dict() estimate = @test_logs (:info, fit_log) estimator(estimand, dataset; cache=cache, verbosity=verbosity) @@ -25,25 +25,27 @@ reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) @test estimate isa TMLE.MLConditionalDistribution @test fitted_params(estimate.machine).features == expected_features ŷ = predict(estimate, dataset) - @test ŷ == predict(estimate.machine, dataset[!, expected_features]) + mach_ŷ = predict(estimate.machine, dataset[!, expected_features]) + @test all(ŷ[i].prob_given_ref == mach_ŷ[i].prob_given_ref for i in eachindex(ŷ)) μ̂ = TMLE.expected_value(estimate, dataset) - @test μ̂ == mean.(ŷ) - @test_skip all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) # The pdf is not necessarily between 0 and 1 + @test μ̂ == [ŷ[i].prob_given_ref[2] for i in eachindex(ŷ)] + @test all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) # The pdf is not necessarily between 0 and 1 # Uses the cache instead of fitting - new_estimator = TMLE.MLConditionalDistributionEstimator(LinearRegressor()) + new_estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier()) @test TMLE.key(new_estimator) == TMLE.key(estimator) @test_logs (:info, reuse_log) estimator(estimand, dataset; cache=cache, verbosity=verbosity) # Changing the model leads to refit - new_estimator = TMLE.MLConditionalDistributionEstimator(LinearRegressor(fit_intercept=false)) + new_estimator = TMLE.MLConditionalDistributionEstimator(LinearBinaryClassifier(fit_intercept=false)) + @test TMLE.key(new_estimator) != TMLE.key(estimator) @test_logs (:info, fit_log) new_estimator(estimand, dataset; cache=cache, verbosity=verbosity) end @testset "Test SampleSplitMLConditionalDistributionEstimator" begin nfolds = 3 train_validation_indices = collect(MLJBase.train_test_pairs(CV(nfolds=nfolds), 1:n, dataset)) - model = LinearRegressor() + model = LinearBinaryClassifier() estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( - LinearRegressor(), + model, train_validation_indices ) cache = Dict() @@ -58,27 +60,28 @@ end # The predictions on validation samples are made from # the machine trained on the train sample ŷfold = predict(estimate.machines[foldid], dataset[val, expected_features]) - @test ŷ[val] == ŷfold - @test μ̂[val] == mean.(ŷfold) + @test [ŷᵢ.prob_given_ref for ŷᵢ ∈ ŷ[val]] == [ŷᵢ.prob_given_ref for ŷᵢ ∈ ŷfold] + @test μ̂[val] == [ŷᵢ.prob_given_ref[2] for ŷᵢ ∈ ŷfold] end - @test_skip all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) + @test all(0. <= x <= 1. for x in TMLE.likelihood(estimate, dataset)) # Uses the cache instead of fitting new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( - LinearRegressor(), + LinearBinaryClassifier(), train_validation_indices ) @test TMLE.key(new_estimator) == TMLE.key(estimator) @test_logs (:info, reuse_log) estimator(estimand, dataset;cache=cache, verbosity=verbosity) # Changing the model leads to refit + new_model = LinearBinaryClassifier(fit_intercept=false) new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( - LinearRegressor(fit_intercept=false), + new_model, train_validation_indices ) @test_logs (:info, fit_log) new_estimator(estimand, dataset; cache=cache, verbosity=verbosity) # Changing the train/validation splits leads to refit train_validation_indices = collect(MLJBase.train_test_pairs(CV(nfolds=4), 1:n, dataset)) new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( - LinearRegressor(), + new_model, train_validation_indices ) @test_logs (:info, fit_log) new_estimator(estimand, dataset; cache=cache, verbosity=verbosity) From 9afac0552136a4e75d69410730555e13aca32bda Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 24 Sep 2023 14:38:56 +0100 Subject: [PATCH 080/151] add dr test for cv tmle and cv ose --- .../double_robustness_ate.jl | 141 +++++------------- .../double_robustness_iate.jl | 91 ++++------- .../estimators_and_estimates.jl | 11 +- test/helper_fns.jl | 23 ++- 4 files changed, 94 insertions(+), 172 deletions(-) diff --git a/test/counterfactual_mean_based/double_robustness_ate.jl b/test/counterfactual_mean_based/double_robustness_ate.jl index 6a2eacc2..d0c75920 100644 --- a/test/counterfactual_mean_based/double_robustness_ate.jl +++ b/test/counterfactual_mean_based/double_robustness_ate.jl @@ -117,18 +117,12 @@ end Y = with_encoder(MLJModels.DeterministicConstantRegressor()), T = LogisticClassifier(lambda=0) ) - tmle = TMLEE(models) - ose = OSE(models) - naive = NAIVE(models.Y) - cache = Dict() - tmle_result, cache = tmle(Ψ, dataset, cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # The initial estimate is far away + naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @test naive_result == 0 @@ -137,16 +131,9 @@ end Y = with_encoder(TreatmentTransformer() |> LinearRegressor()), T = ConstantClassifier() ) - tmle.models = models - ose.models = models - naive.model = models.Y - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) end @testset "Test Double Robustness ATE on binary_outcome_binary_treatment_pb" begin @@ -161,19 +148,12 @@ end Y = with_encoder(ConstantClassifier()), T = LogisticClassifier(lambda=0) ) - tmle = TMLEE(models) - ose = OSE(models) - naive = NAIVE(models.Y) - cache = Dict() - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-6) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # The initial estimate is far away + naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @test naive_result == 0 # When Q is well specified but G is misspecified @@ -181,16 +161,9 @@ end Y = with_encoder(LogisticClassifier(lambda=0)), T = ConstantClassifier() ) - tmle.models = models - ose.models = models - naive.model = models.Y - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-6) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-6) end @@ -206,19 +179,13 @@ end Y = with_encoder(MLJModels.DeterministicConstantRegressor()), T = LogisticClassifier(lambda=0) ) - tmle = TMLEE(models) - ose = OSE(models) - naive = NAIVE(models.Y) - cache = Dict() + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # The initial estimate is far away + naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @test naive_result == 0 @@ -227,16 +194,9 @@ end Y = with_encoder(LinearRegressor()), T = ConstantClassifier() ) - tmle.models = models - ose.models = models - naive.model = models.Y - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) end @testset "Test Double Robustness ATE with two treatment variables" begin @@ -259,34 +219,19 @@ end T₁ = LogisticClassifier(lambda=0), T₂ = LogisticClassifier(lambda=0) ) - tmle = TMLEE(models) - ose = OSE(models) - naive = NAIVE(models.Y) - cache = Dict() - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, ATE₁₁₋₀₁) - test_coverage(ose_result, ATE₁₁₋₀₁) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₁, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # When Q is well specified but G is misspecified models = ( Y = with_encoder(LinearRegressor()), T₁ = ConstantClassifier(), T₂ = ConstantClassifier() ) - tmle.models = models - ose.models = models - naive.model = models.Y - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, ATE₁₁₋₀₁) - test_coverage(ose_result, ATE₁₁₋₀₁) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₁, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) # Test second ATE, two treatment varies Ψ = ATE( @@ -300,13 +245,10 @@ end T₂ = [:W₁, :W₂], ) ) - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, ATE₁₁₋₀₀) - test_coverage(ose_result, ATE₁₁₋₀₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # When Q is well specified but G is misspecified models = ( @@ -314,16 +256,9 @@ end T₁ = LogisticClassifier(lambda=0), T₂ = LogisticClassifier(lambda=0), ) - tmle.models = models - ose.models = models - naive.model = models.Y - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0) - test_coverage(tmle_result, ATE₁₁₋₀₀) - test_coverage(ose_result, ATE₁₁₋₀₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) end end; diff --git a/test/counterfactual_mean_based/double_robustness_iate.jl b/test/counterfactual_mean_based/double_robustness_iate.jl index f91f2874..6456a148 100644 --- a/test/counterfactual_mean_based/double_robustness_iate.jl +++ b/test/counterfactual_mean_based/double_robustness_iate.jl @@ -184,19 +184,12 @@ end T₁ = LogisticClassifier(lambda=0), T₂ = LogisticClassifier(lambda=0), ) - tmle = TMLEE(models) - ose = OSE(models) - naive = NAIVE(models.Y) - cache = Dict() - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0); - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-9) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # The initial estimate is far away + naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @test naive_result == 0 @@ -206,18 +199,12 @@ end T₁ = ConstantClassifier(), T₂ = ConstantClassifier(), ) - tmle.models = models - ose.models = models - naive.model = models.Y - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0); - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-9) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-9) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # The initial estimate is far away + naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @test naive_result ≈ -0.0 atol=1e-1 end @@ -241,19 +228,13 @@ end T₁ = LogisticClassifier(lambda=0), T₂ = LogisticClassifier(lambda=0), ) - tmle = TMLEE(models) - ose = OSE(models) - naive = NAIVE(models.Y) - cache = Dict() - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # The initial estimate is far away + naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @test naive_result == 0 @@ -263,16 +244,9 @@ end T₁ = ConstantClassifier(), T₂ = ConstantClassifier(), ) - tmle.models = models - ose.models = models - naive.model = models.Y - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0) - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_mean_inf_curve_almost_zero(tmle_result; atol=1e-10) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) end @@ -295,18 +269,11 @@ end T₁ = LogisticClassifier(lambda=0), T₂ = LogisticClassifier(lambda=0) ) - tmle = TMLEE(models) - ose = OSE(models) - naive = NAIVE(models.Y) - cache = Dict() - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0); - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # The initial estimate is far away + naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @test naive_result == 0 @@ -316,17 +283,11 @@ end T₁ = ConstantClassifier(), T₂ = ConstantClassifier(), ) - tmle.models = models - ose.models = models - naive.model = models.Y - - tmle_result, cache = tmle(Ψ, dataset; cache=cache, verbosity=0); - ose_result, cache = ose(Ψ, dataset; cache=cache, verbosity=0); - test_coverage(tmle_result, Ψ₀) - test_coverage(ose_result, Ψ₀) - test_fluct_decreases_risk(cache) - test_fluct_mean_inf_curve_lower_than_initial(tmle_result, ose_result) + dr_estimators = double_robust_estimators(models) + results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) # The initial estimate is far away + naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @test naive_result ≈ -0.02 atol=1e-2 end diff --git a/test/counterfactual_mean_based/estimators_and_estimates.jl b/test/counterfactual_mean_based/estimators_and_estimates.jl index 0bd7085d..d5f371fa 100644 --- a/test/counterfactual_mean_based/estimators_and_estimates.jl +++ b/test/counterfactual_mean_based/estimators_and_estimates.jl @@ -9,8 +9,7 @@ using Distributions using MLJLinearModels using CategoricalArrays using LogExpFunctions - -include(joinpath(dirname(@__DIR__), "helper_fns.jl")) +using MLJBase @testset "Test CMRelevantFactorsEstimator" begin n = 100 @@ -66,9 +65,15 @@ include(joinpath(dirname(@__DIR__), "helper_fns.jl")) ) @test_logs partial_reuse_log... new_η̂(η, dataset; cache=cache, verbosity=1) + # Adding a resampling strategy + resampled_η̂ = TMLE.CMRelevantFactorsEstimator(models=new_models, resampling=CV(nfolds=3)) + @test TMLE.key(η, new_η̂) != TMLE.key(η, resampled_η̂) + η̂ₙ = @test_logs fit_log... resampled_η̂(η, dataset; cache=cache, verbosity=1) + @test length(η̂ₙ.outcome_mean.machines) == 3 + @test length(η̂ₙ.propensity_score[1].machines) == 3 + @test η̂ₙ.outcome_mean.train_validation_indices == η̂ₙ.propensity_score[1].train_validation_indices end - end true \ No newline at end of file diff --git a/test/helper_fns.jl b/test/helper_fns.jl index c9a2a6bf..aaf3ab92 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -58,4 +58,25 @@ test_mean_inf_curve_almost_zero(tmle_result::TMLE.EICEstimate; atol=1e-10) = @te This cqnnot be guaranteed in general since a well specified maximum likelihood estimator also solves the score equation. """ -test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.TMLEstimate, ose_result::TMLE.OSEstimate) = @test abs(mean(tmle_result.IC)) < abs(mean(ose_result.IC)) \ No newline at end of file +test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.TMLEstimate, ose_result::TMLE.OSEstimate) = @test abs(mean(tmle_result.IC)) < abs(mean(ose_result.IC)) + +double_robust_estimators(models; resampling=CV(nfolds=3)) = ( + tmle = TMLEE(models), + ose = OSE(models), + cv_tmle = TMLEE(models, resampling=resampling), + cv_ose = TMLEE(models, resampling=resampling), +) + +function test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) + cache = Dict() + results = [] + for estimator ∈ dr_estimators + result, cache = estimator(Ψ, dataset, cache=cache, verbosity=verbosity) + push!(results, result) + test_coverage(result, Ψ₀) + if estimator isa TMLEE && estimator.resampling === nothing + test_fluct_decreases_risk(cache) + end + end + return NamedTuple{keys(dr_estimators)}(results), cache +end \ No newline at end of file From 53910290e545b3062e5013bb7db7d41a1397a7ca Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 24 Sep 2023 15:42:26 +0100 Subject: [PATCH 081/151] add type to method --- src/counterfactual_mean_based/estimators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 4da0d84b..2b387679 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -111,7 +111,7 @@ end TMLEE(models; resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) = TMLEE(models, resampling, ps_lowerbound, weighted, tol) -function (tmle::TMLEE)(Ψ, dataset; cache=Dict(), verbosity=1) +function (tmle::TMLEE)(Ψ::CMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors From 0721495bab52ef3970c4f67c21bb37a5de266d06 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 24 Sep 2023 16:19:48 +0100 Subject: [PATCH 082/151] add types to structs fields --- README.md | 10 ++++--- src/TMLE.jl | 4 +-- src/counterfactual_mean_based/estimators.jl | 26 +++++++++---------- src/estimates.jl | 4 +-- src/estimators.jl | 6 ++--- .../estimators_and_estimates.jl | 6 +++++ test/estimators_and_estimates.jl | 4 +-- 7 files changed, 34 insertions(+), 26 deletions(-) diff --git a/README.md b/README.md index 982a0f31..b73ba824 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,11 @@ This package enables the estimation of various Causal Inference related estimand ## Notes to self -- Factors should probably not be exposed. For instance resampling should only be specified once per estimation procedure and not for each factor. - For multiple TMLE steps, convergence when the mean of the IC is below: σ/n where σ is the standard deviation of the IC. -- Try make selectcols preserve input type -- Watch for discrepancy in train_validation_indices with missing data... \ No newline at end of file +- Try make selectcols preserve input type (see MLJModel has a facility) +- Watch for discrepancy in train_validation_indices with missing data... +- Check that for IATEs simulations, the true effect size is actually an additive interaction and not a multiplicative one. +- clean imports in test files +- clean export list +- look into dataset management and missing values +- See about ordering of parameters \ No newline at end of file diff --git a/src/TMLE.jl b/src/TMLE.jl index 3ca9b54f..91803af2 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -25,8 +25,6 @@ import AbstractDifferentiation as AD export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel -export setequation!, getequation, get_conditional_distribution, set_conditional_distribution! -export setmodel!, equations, reset!, parents export ConditionalDistribution, ExpectedValue export CounterfactualMean, CM export AverageTreatmentEffect, ATE @@ -52,9 +50,9 @@ include("treatment_transformer.jl") include("adjustment.jl") include("counterfactual_mean_based/estimands.jl") -include("counterfactual_mean_based/estimators.jl") include("counterfactual_mean_based/estimates.jl") include("counterfactual_mean_based/fluctuation.jl") +include("counterfactual_mean_based/estimators.jl") include("counterfactual_mean_based/clever_covariate.jl") include("counterfactual_mean_based/gradient.jl") diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 2b387679..d198ccbc 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -3,8 +3,8 @@ ##################################################################### struct CMRelevantFactorsEstimator <: Estimator - resampling - models + resampling::Union{Nothing, ResamplingStrategy} + models::NamedTuple end CMRelevantFactorsEstimator(;models, resampling=nothing) = @@ -19,7 +19,7 @@ function get_train_validation_indices(resampling::ResamplingStrategy, factors, d relevant_columns = collect(TMLE.variables(factors)) outcome_variable = factors.outcome_mean.outcome feature_variables = filter(x -> x !== outcome_variable, relevant_columns) - return collect(MLJBase.train_test_pairs( + return Tuple(MLJBase.train_test_pairs( resampling, 1:nrows(dataset), selectcols(dataset, feature_variables), @@ -69,7 +69,7 @@ end ##################################################################### struct TargetedCMRelevantFactorsEstimator - model + model::Fluctuation end TargetedCMRelevantFactorsEstimator(Ψ, initial_factors_estimate; tol=nothing, ps_lowerbound=1e-8, weighted=false) = @@ -101,11 +101,11 @@ end ### TMLE ### ##################################################################### mutable struct TMLEE <: Estimator - models - resampling - ps_lowerbound - weighted - tol + models::NamedTuple + resampling::Union{Nothing, ResamplingStrategy} + ps_lowerbound::Float64 + weighted::Bool + tol::Union{Float64, Nothing} end TMLEE(models; resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) = @@ -146,9 +146,9 @@ end ##################################################################### mutable struct OSE <: Estimator - models - resampling - ps_lowerbound + models::NamedTuple + resampling::Union{Nothing, ResamplingStrategy} + ps_lowerbound::Float64 end OSE(models; resampling=nothing, ps_lowerbound=1e-8) = @@ -184,7 +184,7 @@ end ##################################################################### mutable struct NAIVE <: Estimator - model + model::MLJBase.Supervised end function (estimator::NAIVE)(Ψ::CMCompositeEstimand, dataset; cache=Dict(), verbosity=1) diff --git a/src/estimates.jl b/src/estimates.jl index 5e33ec1d..f2b79f9a 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -40,8 +40,8 @@ Holds a Sample Split Machine Learning estimate for a Conditional Distribution. """ struct SampleSplitMLConditionalDistribution <: Estimate estimand::ConditionalDistribution - train_validation_indices - machines + train_validation_indices::Tuple + machines::Vector{Machine} end string_repr(estimate::SampleSplitMLConditionalDistribution) = diff --git a/src/estimators.jl b/src/estimators.jl index 5ac87a28..7c85bea7 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -5,7 +5,7 @@ abstract type Estimator end ##################################################################### struct MLConditionalDistributionEstimator <: Estimator - model + model::MLJBase.Supervised end function (estimator::MLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1) @@ -41,8 +41,8 @@ key(estimator::MLConditionalDistributionEstimator) = ##################################################################### struct SampleSplitMLConditionalDistributionEstimator <: Estimator - model - train_validation_indices + model::MLJBase.Supervised + train_validation_indices::Tuple end function (estimator::SampleSplitMLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1) diff --git a/test/counterfactual_mean_based/estimators_and_estimates.jl b/test/counterfactual_mean_based/estimators_and_estimates.jl index d5f371fa..dcd80468 100644 --- a/test/counterfactual_mean_based/estimators_and_estimates.jl +++ b/test/counterfactual_mean_based/estimators_and_estimates.jl @@ -74,6 +74,12 @@ using MLJBase @test η̂ₙ.outcome_mean.train_validation_indices == η̂ₙ.propensity_score[1].train_validation_indices end +@testset "Test structs are concrete types" begin + for type in (OSE, TMLEE, NAIVE) + @test isconcretetype(type) + end +end + end true \ No newline at end of file diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index 4fd0fe65..94bf2f65 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -42,7 +42,7 @@ end @testset "Test SampleSplitMLConditionalDistributionEstimator" begin nfolds = 3 - train_validation_indices = collect(MLJBase.train_test_pairs(CV(nfolds=nfolds), 1:n, dataset)) + train_validation_indices = Tuple(MLJBase.train_test_pairs(CV(nfolds=nfolds), 1:n, dataset)) model = LinearBinaryClassifier() estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( model, @@ -79,7 +79,7 @@ end ) @test_logs (:info, fit_log) new_estimator(estimand, dataset; cache=cache, verbosity=verbosity) # Changing the train/validation splits leads to refit - train_validation_indices = collect(MLJBase.train_test_pairs(CV(nfolds=4), 1:n, dataset)) + train_validation_indices = Tuple(MLJBase.train_test_pairs(CV(nfolds=4), 1:n, dataset)) new_estimator = TMLE.SampleSplitMLConditionalDistributionEstimator( new_model, train_validation_indices From 5cafe379aafa88aae08ba43d06b4cf778db3e94f Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 25 Sep 2023 10:16:59 +0100 Subject: [PATCH 083/151] clean some files --- .../double_robustness_ate.jl | 1 - test/counterfactual_mean_based/fluctuation.jl | 1 - .../non_regression_test.jl | 1 - test/distribution_factors.jl | 69 ------------------- test/estimands.jl | 1 - test/treatment_transformer.jl | 1 + test/utils.jl | 1 - 7 files changed, 1 insertion(+), 74 deletions(-) delete mode 100644 test/distribution_factors.jl diff --git a/test/counterfactual_mean_based/double_robustness_ate.jl b/test/counterfactual_mean_based/double_robustness_ate.jl index d0c75920..236da304 100644 --- a/test/counterfactual_mean_based/double_robustness_ate.jl +++ b/test/counterfactual_mean_based/double_robustness_ate.jl @@ -9,7 +9,6 @@ using MLJLinearModels using MLJModels using StableRNGs using StatsBase -using HypothesisTests using LogExpFunctions include(joinpath(dirname(@__DIR__), "helper_fns.jl")) diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index ceceb986..c20423a8 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -6,7 +6,6 @@ using MLJModels using MLJBase using DataFrames - @testset "Test Fluctuation with 1 Treatments" begin Ψ = ATE( outcome=:Y, diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index e67ad5c9..11f07964 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -8,7 +8,6 @@ using CategoricalArrays using MLJGLMInterface using MLJBase - @testset "Test ATE on perinatal dataset." begin # This is a non-regression test which was checked against the R tmle3 package dataset = CSV.read(joinpath("data", "perinatal.csv"), DataFrame, missingstring=["", "NA"]) diff --git a/test/distribution_factors.jl b/test/distribution_factors.jl deleted file mode 100644 index 0ed19925..00000000 --- a/test/distribution_factors.jl +++ /dev/null @@ -1,69 +0,0 @@ -module TestDistributionFactors - -using Test -using TMLE -using MLJGLMInterface -using MLJBase - -@testset "Test ConditionalDistribution(outcome, parents, model, resampling)" begin - # Default constructor - cond_dist = ConditionalDistribution("Y", [:W, "T", :C], LinearBinaryClassifier()) - @test cond_dist.outcome == :Y - @test cond_dist.parents == Set([:W, :T, :C]) - @test cond_dist.model == LinearBinaryClassifier() - @test TMLE.key(cond_dist) == (:Y, Set([:W, :T, :C])) - # String representation - @test TMLE.string_repr(cond_dist) == "Y | T, W, C" -end - -@testset "Test fit!" begin - X, y = make_regression() - dataset = (Y = y, X₁ = X.x1, X₂ = X.x2) - cond_dist = ConditionalDistribution(:Y, [:X₁, :X₂], LinearRegressor()) - # Initial fit - log_sequence = ( - (:info, TMLE.fit_message(cond_dist)), - (:info, "Training machine(LinearRegressor(fit_intercept = true, …), …).") - ) - @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2) - fp = fitted_params(cond_dist.machine) - @test fp.features == [:X₂, :X₁] - @test fp.intercept != 0. - # Change the model - cond_dist.model = LinearRegressor(fit_intercept=false) - log_sequence = ( - (:info, TMLE.fit_message(cond_dist)), - (:info, "Training machine(LinearRegressor(fit_intercept = false, …), …).") - ) - @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2) - fp = fitted_params(cond_dist.machine) - @test fp.features == [:X₂, :X₁] - @test fp.intercept == 0. - # Change model's hyperparameter - cond_dist.model.fit_intercept = true - log_sequence = ( - (:info, TMLE.update_message(cond_dist)), - (:info, "Updating machine(LinearRegressor(fit_intercept = true, …), …).") - ) - @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2) - fp = fitted_params(cond_dist.machine) - @test fp.features == [:X₂, :X₁] - @test fp.intercept != 0. - # Forcing refit - log_sequence = ( - (:info, TMLE.update_message(cond_dist)), - (:info, "Training machine(LinearRegressor(fit_intercept = true, …), …).") - ) - @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2, force = true) - # No refit - log_sequence = ( - (:info, TMLE.update_message(cond_dist)), - (:info, "Not retraining machine(LinearRegressor(fit_intercept = true, …), …). Use `force=true` to force.") - ) - @test_logs log_sequence... MLJBase.fit!(cond_dist, dataset, verbosity=2) - -end - -end - -true \ No newline at end of file diff --git a/test/estimands.jl b/test/estimands.jl index ee751731..e239ea9a 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -3,7 +3,6 @@ module TestEstimands using TMLE using Test - @testset "Test ConditionalDistribution" begin distr = ConditionalDistribution("Y", ["C", 1, :A, ]) @test distr.outcome === :Y diff --git a/test/treatment_transformer.jl b/test/treatment_transformer.jl index f904bf56..9302af0a 100644 --- a/test/treatment_transformer.jl +++ b/test/treatment_transformer.jl @@ -1,4 +1,5 @@ module TestTreatmentTransformer + using Test using TMLE using CategoricalArrays diff --git a/test/utils.jl b/test/utils.jl index 2ac87525..76725fd7 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -3,7 +3,6 @@ module TestUtils using Test using TMLE using MLJBase -using StableRNGs using CategoricalArrays using MLJLinearModels using MLJModels From 1e77a440e9da4c1ca8abfd0230232366d71acaed Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 25 Sep 2023 13:40:43 +0100 Subject: [PATCH 084/151] towards inclusion of scm --- src/TMLE.jl | 5 +- src/counterfactual_mean_based/estimands.jl | 73 +++--- src/scm.jl | 270 +++++---------------- 3 files changed, 90 insertions(+), 258 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index 91803af2..f8d41e1c 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -18,19 +18,20 @@ using PrecompileTools using PrettyTables using Random import AbstractDifferentiation as AD +using Graphs +using MetaGraphsNext # ############################################################################# # EXPORTS # ############################################################################# -export SE, StructuralEquation export StructuralCausalModel, SCM, StaticConfoundedModel export ConditionalDistribution, ExpectedValue export CounterfactualMean, CM export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS -export fit!, optimize_ordering, optimize_ordering! +export optimize_ordering, optimize_ordering! export TMLEE, OSE, NAIVE export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 122e5605..b58a6e87 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -38,42 +38,30 @@ ESTIMANDS_DOCS = Dict( :ATE => (formula="``ATE(Y, T, case, control) = E[Y|do(T=case)] - E[Y|do(T=control)``",), :IATE => (formula="``IATE = E[Y|do(T₁=1, T₂=1)] - E[Y|do(T₁=1, T₂=0)] - E[Y|do(T₁=0, T₂=1)] + E[Y|do(T₁=0, T₂=0)]``",) ) -# Define constructors/name for CMCompositeEstimand types -for (typename, (formula,)) ∈ ESTIMANDS_DOCS +for (estimand, (formula,)) ∈ ESTIMANDS_DOCS + causal_estimand = Symbol(:Causal, estimand) + statistical_estimand = Symbol(:Statistical, estimand) ex = quote - # """ - # # $(typename) - - # ## Definition - - # For two treatments with case/control settings (1, 0): - - # $(formula) - - # ## Constructors - - # - $(typename)(outcome, treatment, confounders; extra_outcome_covariates=Set{Symbol}()) - # - $(typename)(; - # outcome, - # treatment, - # confounders, - # extra_outcome_covariates - # ) - - # ## Example + # Causal Estimand + struct $(causal_estimand) <: Estimand + outcome::Symbol + treatment_values::NamedTuple - # ```julia - # Ψ = $(typename)(outcome=:Y, treatment=(T₁=(case=1, control=0), T₂=(case=1,control=0), confounders=[:W₁, :W₂]) - # ``` - # """ - struct $(typename) <: Estimand + function $(statistical_estimand)(outcome, treatment_values) + outcome = Symbol(outcome) + treatment_variables = Tuple(keys(treatment_values)) + return new(outcome, treatment_values) + end + end + # Statistical Estimand + struct $(statistical_estimand) <: Estimand outcome::Symbol treatment_values::NamedTuple treatment_confounders::NamedTuple outcome_extra_covariates::Tuple{Vararg{Symbol}} - function $(typename)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) + function $(statistical_estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) outcome = Symbol(outcome) treatment_variables = Tuple(keys(treatment_values)) treatment_confounders = NamedTuple{treatment_variables}([unique_sorted_tuple(treatment_confounders[T]) for T ∈ treatment_variables]) @@ -82,22 +70,27 @@ for (typename, (formula,)) ∈ ESTIMANDS_DOCS end end - $(typename)(outcome, treatment_values, treatment_confounders; outcome_extra_covariates=()) = - $(typename)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) - - $(typename)(;outcome, treatment_values, treatment_confounders, outcome_extra_covariates=()) = - $(typename)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) + # Constructors + $(estimand)(outcome, treatment_values) = $(causal_estimand)(outcome, treatment_values) + + $(estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) = + $(statistical_estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) + + $(estimand)(outcome, treatment_values, treatment_confounders::Nothing, outcome_extra_covariates) = + $(statistical_estimand)(outcome, treatment_values) + + $(estimand)(;outcome, treatment_values, treatment_confounders, outcome_extra_covariates=()) = + $(estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) - name(::Type{$(typename)}) = string($(typename)) end eval(ex) end -CMCompositeEstimand = Union{(eval(x) for x in keys(ESTIMANDS_DOCS))...} +CMCompositeEstimand = Union{(eval(Symbol(:Statistical, x)) for x in keys(ESTIMANDS_DOCS))...} -indicator_fns(Ψ::CM) = Dict(values(Ψ.treatment_values) => 1.) +indicator_fns(Ψ::StatisticalCM) = Dict(values(Ψ.treatment_values) => 1.) -function indicator_fns(Ψ::ATE) +function indicator_fns(Ψ::StatisticalATE) case = [] control = [] for treatment in Ψ.treatment_values @@ -107,9 +100,9 @@ function indicator_fns(Ψ::ATE) return Dict(Tuple(case) => 1., Tuple(control) => -1.) end -ncases(value, Ψ::IATE) = sum(value[i] == Ψ.treatment_values[i].case for i in eachindex(value)) +ncases(value, Ψ::StatisticalIATE) = sum(value[i] == Ψ.treatment_values[i].case for i in eachindex(value)) -function indicator_fns(Ψ::IATE) +function indicator_fns(Ψ::StatisticalIATE) N = length(treatments(Ψ)) key_vals = Pair[] for cf in Iterators.product((values(Ψ.treatment_values[T]) for T in treatments(Ψ))...) @@ -126,7 +119,7 @@ end function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand param_string = string( - name(T), + Base.typename(T).wrapper, "\n-----", "\nOutcome: ", Ψ.outcome, "\nTreatment: ", Ψ.treatment_values diff --git a/src/scm.jl b/src/scm.jl index cfd0bc4e..6055f4c9 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -2,238 +2,76 @@ ### Structural Equation ### ##################################################################### -SelfReferringEquationError(outcome) = - ArgumentError(string("Variable ", outcome, " appears on both sides of the equation.")) - -""" -# Structural Equation / SE - -## Constructors - -- SE(outcome, parents) -- SE(;outcome, parents) - -## Examples - -eq = SE(:Y, [:T, :W]) -eq = SE(outcome=:Y, parents=[:T, :W]) - -""" -struct StructuralEquation - outcome::Symbol - parents::Set{Symbol} - function StructuralEquation(outcome, parents) - outcome = Symbol(outcome) - parents = Set(Symbol(x) for x in parents) - outcome ∉ parents || throw(SelfReferringEquationError(outcome)) - return new(outcome, parents) +function SCM(equations...) + scm = MetaGraph( + SimpleDiGraph(); + label_type=Symbol, + vertex_data_type=Nothing, + edge_data_type=Nothing + ) + for (outcome, parents) in equations + add_equation!(scm, outcome, parents) end + return scm end -const SE = StructuralEquation - -SE(;outcome, parents) = SE(outcome, parents) - -outcome(se::SE) = se.outcome -parents(se::SE) = se.parents +function add_equation!(scm::MetaGraph, outcome, parents) + outcome_symbol = Symbol(outcome) + add_vertex!(scm, outcome_symbol) + for parent in parents + parent_symbol = Symbol(parent) + add_vertex!(scm, parent_symbol) + add_edge!(scm, parent_symbol, outcome_symbol) + end +end -string_repr(eq::SE; subscript="") = string(outcome(eq), " = f", subscript, "(", join(parents(eq), ", "), ")") -Base.show(io::IO, ::MIME"text/plain", eq::SE) = println(io, string_repr(eq)) +get_outcome_extra_covariates(outcome_parents, treatment_and_confounders, vertex_labels) = + [vertex_labels[p] for p ∈ outcome_parents if p ∉ treatment_and_confounders] ##################################################################### -### Structural Causal Model ### +### Identification Methods ### ##################################################################### -AlreadyAssignedError(key) = ArgumentError(string("Variable ", key, " is already assigned in the SCM.")) - -""" -# Structural Causal Model / SCM - -## Constructors - -SCM(;equations=Dict{Symbol, SE}()) -SCM(equations::Vararg{SE}) +abstract type AdjustmentMethod end -## Examples - -scm = SCM( - SE(:Y, [:T, :W, :C]), - SE(:T, [:W]) -) - -""" -struct StructuralCausalModel - equations::Dict{Symbol, StructuralEquation} - factors::Dict - StructuralCausalModel(equations) = new(equations, Dict()) -end - -const SCM = StructuralCausalModel - -SCM(;equations=Dict{Symbol, SE}()) = SCM(equations) -SCM(equations::Vararg{SE}) = SCM(Dict(outcome(eq) => eq for eq in equations)) - -equations(scm::SCM) = scm.equations -factors(scm::SCM) = scm.factors - -parents(scm::StructuralCausalModel, outcome::Symbol) = - haskey(equations(scm), outcome) ? parents(equations(scm)[outcome]) : Set{Symbol}() - -setequation!(scm::SCM, eq::SE) = scm.equations[outcome(eq)] = eq - -getequation(scm::StructuralCausalModel, key::Symbol) = scm.equations[key] - -get_conditional_distribution(scm::SCM, outcome::Symbol) = get_conditional_distribution(scm, outcome, parents(scm, outcome)) -get_conditional_distribution(scm::SCM, outcome::Symbol, parents) = scm.factors[(outcome, Set(parents))] - -NoAvailableConditionalDistributionError(scm, outcome, conditioning_set) = ArgumentError(string( - "Could not find a conditional distribution for either : \n - The required: ", - cond_dist_string(outcome, conditioning_set), - "\n - The natural (that could be used as a template): ", - cond_dist_string(outcome, parents(scm, outcome)), - "\nSet at least one of them using `set_conditional_distribution!` before proceeding to estimation." - )) - -UsingNaturalDistributionLog(scm, outcome, conditioning_set) = string( - "Could not retrieve a conditional distribution for the required: ", - cond_dist_string(outcome, conditioning_set), - ".\n Will use the model class found in the natural candidate: ", - cond_dist_string(outcome, parents(scm, outcome)) -) - -function get_or_set_conditional_distribution_from_natural!(scm::SCM, outcome::Symbol, conditioning_set::Set{Symbol}; resampling=nothing, verbosity=1) - try - # A factor has already been setup - factor = get_conditional_distribution(scm, outcome, conditioning_set) - # If the resampling strategies also match then we can reuse the already set factor - new_factor = ConditionalDistribution(outcome, conditioning_set, factor.model, resampling=resampling) - if match(factor, new_factor) - return factor - else - setfactor!(scm, new_factor) - return new_factor - end - catch KeyError - # A factor has not already been setup - try - # A natural factor can be used as a template - template_factor = get_conditional_distribution(scm, outcome) - new_factor = set_conditional_distribution!(scm, outcome, conditioning_set, template_factor.model, resampling) - verbosity > 0 && @info(UsingNaturalDistributionLog(scm, outcome, conditioning_set)) - return new_factor - catch KeyError - throw(NoAvailableConditionalDistributionError(scm, outcome, conditioning_set)) - end - end +struct BackdoorAdjustment <: AdjustmentMethod + outcome_extra::Bool end -set_conditional_distribution!(scm::SCM, outcome::Symbol, model, resampling) = - set_conditional_distribution!(scm, outcome, parents(scm, outcome), model, resampling) - -function set_conditional_distribution!(scm::SCM, outcome::Symbol, parents, model, resampling) - factor = ConditionalDistribution(outcome, parents, model, resampling=resampling) - return setfactor!(scm, factor) -end - -function setfactor!(scm::SCM, factor) - scm.factors[key(factor)] = factor - return factor -end - -function string_repr(scm::SCM) - scm_string = """ - Structural Causal Model: - ----------------------- - """ - for (index, (_, eq)) in enumerate(scm.equations) - digits = split(string(index), "") - subscript = join(Meta.parse("'\\U0208$digit'") for digit in digits) - scm_string = string(scm_string, string_repr(eq;subscript=subscript), "\n") - end - return scm_string -end - -Base.show(io::IO, ::MIME"text/plain", scm::SCM) = println(io, string_repr(scm)) - - -""" - is_upstream(var₁, var₂, scm::SCM) - -Checks whether var₁ is upstream of var₂ in the SCM. -""" -function is_upstream(var₁, var₂, scm::SCM) - upstream = parents(scm, var₂) - while true - if var₁ ∈ upstream - return true - else - upstream = union((parents(scm, v) for v in upstream)...) - end - isempty(upstream) && return false +function identify(method::BackdoorAdjustment, estimand::T, scm::MetaGraph) where T<:Union{Cau} + # Treatment confounders + treatment_names = keys(estimand.treatment) + treatment_codes = [code_for(scm, treatment) for treatment ∈ treatment_names] + confounders_codes = scm.graph.badjlist[treatment_codes] + treatment_confounders = NamedTuple{treatment_names}( + [[scm.vertex_labels[w] for w in confounders_codes[i]] + for i in eachindex(confounders_codes)] + ) + # Extra covariates + outcome_extra_covariates = nothing + if method.outcome_extra_covariates + outcome_parents_codes = scm.graph.badjlist[code_for(scm, estimand.outcome)] + treatment_and_confounders_codes = Set(vcat(treatment_codes, confounders_codes...)) + outcome_extra_covariates = get_outcome_extra_covariates( + outcome_parents_codes, + treatment_and_confounders_codes, + scm.vertex_labels + ) end -end - -VariableNotAChildInSCMError(variable) = ArgumentError(string("Variable ", variable, " is not associated with a Structural Equation in the SCM.")) -TreatmentMustBeInOutcomeParentsError(variable) = ArgumentError(string("Treatment variable ", variable, " must be a parent of the outcome.")) -function check_parameter_against_scm(scm::SCM, outcome, treatment) - eqs = equations(scm) - haskey(eqs, outcome) || throw(VariableNotAChildInSCMError(outcome)) - for treatment_variable in keys(treatment) - haskey(eqs, treatment_variable) || throw(VariableNotAChildInSCMError(treatment_variable)) - is_upstream(treatment_variable, outcome, scm) || throw(TreatmentMustBeInOutcomeParentsError(treatment_variable)) - end + return T( + outcome=estimand.outcome, + treatment_values = estimand.treatment_values, + treatment_confounders = treatment_confounders, + outcome_extra_covariates = outcome_extra_covariates + ) end ##################################################################### -### StaticConfoundedModel ### +### Causal Estimand ### ##################################################################### - -combine_outcome_parents(treatment, confounders, ::Nothing) = vcat(treatment, confounders) -combine_outcome_parents(treatment, confounders, covariates) = vcat(treatment, confounders, covariates) - -""" - StaticConfoundedModel( - outcome::Symbol, treatment::Symbol, confounders::Union{Symbol, AbstractVector{Symbol}}; - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing - ) - -Defines a classic Structural Causal Model with one outcome, one treatment, -a set of confounding variables and optional covariates influencing the outcome only. -""" -function StaticConfoundedModel( - outcome, treatment, confounders; - covariates = nothing - ) - Yeq = SE(outcome, combine_outcome_parents(treatment, confounders, covariates)) - Teq = SE(treatment, vcat(confounders)) - return StructuralCausalModel(Yeq, Teq) -end - -""" - StaticConfoundedModel( - outcomes::Vector{Symbol}, - treatments::Vector{Symbol}, - confounders::Union{Symbol, AbstractVector{Symbol}}; - covariates::Union{Nothing, Symbol, AbstractVector{Symbol}} = nothing, - outcome_model = TreatmentTransformer() |> LinearRegressor(), - treatment_model = LinearBinaryClassifier() - ) - -Defines a classic Structural Causal Model with multiple outcomes, multiple treatments, -a set of confounding variables and optional covariates influencing the outcomes only. - -All treatments are assumed to be direct parents of all outcomes. The confounding variables -are shared for all treatments. -""" -function StaticConfoundedModel( - outcomes::AbstractVector, - treatments::AbstractVector, - confounders; - covariates = nothing - ) - Yequations = (SE(outcome, combine_outcome_parents(treatments, confounders, covariates)) for outcome in outcomes) - Tequations = (SE(treatment, vcat(confounders)) for treatment in treatments) - return SCM(Yequations..., Tequations...) +struct CausalATE + outcome::Symbol + treatment_values::NamedTuple end \ No newline at end of file From 8a81c51e9b4c856f49a9a0819960a024deac1880 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 25 Sep 2023 15:48:18 +0100 Subject: [PATCH 085/151] add scm and identification --- src/TMLE.jl | 10 +- src/adjustment.jl | 28 ----- src/counterfactual_mean_based/adjustment.jl | 37 ++++++ .../clever_covariate.jl | 6 +- src/counterfactual_mean_based/estimands.jl | 16 +-- src/counterfactual_mean_based/estimators.jl | 6 +- src/counterfactual_mean_based/fluctuation.jl | 2 +- src/counterfactual_mean_based/gradient.jl | 8 +- src/scm.jl | 79 +++++------- test/adjustment.jl | 28 ----- test/counterfactual_mean_based/adjustment.jl | 35 ++++++ test/counterfactual_mean_based/estimands.jl | 4 +- test/estimands.jl | 12 +- test/estimators_and_estimates.jl | 2 +- test/runtests.jl | 7 +- test/scm.jl | 114 +++++------------- 16 files changed, 170 insertions(+), 224 deletions(-) delete mode 100644 src/adjustment.jl create mode 100644 src/counterfactual_mean_based/adjustment.jl delete mode 100644 test/adjustment.jl create mode 100644 test/counterfactual_mean_based/adjustment.jl diff --git a/src/TMLE.jl b/src/TMLE.jl index f8d41e1c..b6a80723 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -25,8 +25,7 @@ using MetaGraphsNext # EXPORTS # ############################################################################# -export StructuralCausalModel, SCM, StaticConfoundedModel -export ConditionalDistribution, ExpectedValue +export SCM, StaticSCM, add_equations!, add_equation!, parents export CounterfactualMean, CM export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE @@ -36,19 +35,18 @@ export TMLEE, OSE, NAIVE export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose export TreatmentTransformer, with_encoder -export BackdoorAdjustment +export BackdoorAdjustment, identify # ############################################################################# # INCLUDES # ############################################################################# include("utils.jl") -# include("scm.jl") +include("scm.jl") include("estimands.jl") include("estimators.jl") include("estimates.jl") include("treatment_transformer.jl") -include("adjustment.jl") include("counterfactual_mean_based/estimands.jl") include("counterfactual_mean_based/estimates.jl") @@ -56,6 +54,8 @@ include("counterfactual_mean_based/fluctuation.jl") include("counterfactual_mean_based/estimators.jl") include("counterfactual_mean_based/clever_covariate.jl") include("counterfactual_mean_based/gradient.jl") +include("counterfactual_mean_based/adjustment.jl") + # ############################################################################# diff --git a/src/adjustment.jl b/src/adjustment.jl deleted file mode 100644 index 3853ca05..00000000 --- a/src/adjustment.jl +++ /dev/null @@ -1,28 +0,0 @@ - -abstract type AdjustmentMethod end - -struct BackdoorAdjustment <: AdjustmentMethod - outcome_extra::Vector{Symbol} -end - -""" - BackdoorAdjustment(;outcome_extra=[]) - -The adjustment set for each treatment variable is simply the set of direct parents in the -associated structural model. - -`outcome_extra` are optional additional variables that can be used to fit the outcome model -in order to improve inference. -""" -BackdoorAdjustment(;outcome_extra=[]) = BackdoorAdjustment(outcome_extra) - -confounders(::BackdoorAdjustment, Ψ::Estimand) = - union((parents(Ψ.scm, treatment) for treatment in treatments(Ψ))...) - -outcome_parents(adjustment_method::BackdoorAdjustment, Ψ::Estimand) = - Set(vcat( - treatments(Ψ), - confounders(adjustment_method, Ψ)..., - adjustment_method.outcome_extra - )) - diff --git a/src/counterfactual_mean_based/adjustment.jl b/src/counterfactual_mean_based/adjustment.jl new file mode 100644 index 00000000..79e74a3c --- /dev/null +++ b/src/counterfactual_mean_based/adjustment.jl @@ -0,0 +1,37 @@ +##################################################################### +### Identification Methods ### +##################################################################### + +abstract type AdjustmentMethod end + +struct BackdoorAdjustment <: AdjustmentMethod + outcome_extra_covariates::Tuple{Vararg{Symbol}} + BackdoorAdjustment(outcome_extra_covariates) = new(unique_sorted_tuple(outcome_extra_covariates)) +end + +BackdoorAdjustment(;outcome_extra_covariates=false) = + BackdoorAdjustment(outcome_extra_covariates) + +function statistical_type_from_causal_type(T) + typestring = string(Base.typename(T).wrapper) + new_typestring = replace(typestring, "TMLE.Causal" => "") + return eval(Symbol(new_typestring)) +end + +function identify(method::BackdoorAdjustment, causal_estimand::T, scm::MetaGraph) where T<:CausalCMCompositeEstimands + # Treatment confounders + treatment_names = keys(causal_estimand.treatment_values) + treatment_codes = [code_for(scm, treatment) for treatment ∈ treatment_names] + confounders_codes = scm.graph.badjlist[treatment_codes] + treatment_confounders = NamedTuple{treatment_names}( + [[scm.vertex_labels[w] for w in confounders_codes[i]] + for i in eachindex(confounders_codes)] + ) + + return statistical_type_from_causal_type(T)(; + outcome=causal_estimand.outcome, + treatment_values = causal_estimand.treatment_values, + treatment_confounders = treatment_confounders, + outcome_extra_covariates = method.outcome_extra_covariates + ) +end \ No newline at end of file diff --git a/src/counterfactual_mean_based/clever_covariate.jl b/src/counterfactual_mean_based/clever_covariate.jl index 7fbb2f25..d3d2e2cc 100644 --- a/src/counterfactual_mean_based/clever_covariate.jl +++ b/src/counterfactual_mean_based/clever_covariate.jl @@ -1,6 +1,6 @@ """ - data_adaptive_ps_lower_bound(Ψ::CMCompositeEstimand) + data_adaptive_ps_lower_bound(Ψ::StatisticalCMCompositeEstimand) This startegy is from [this paper](https://academic.oup.com/aje/article/191/9/1640/6580570?login=false) but the study does not show strictly better behaviour of the strategy so not a default for now. @@ -29,7 +29,7 @@ end """ clever_covariate_and_weights( - Ψ::CMCompositeEstimand, + Ψ::StatisticalCMCompositeEstimand, Gs::Tuple{Vararg{ConditionalDistributionEstimate}}, dataset; ps_lowerbound=1e-8, @@ -51,7 +51,7 @@ if `weighted_fluctuation = true`: where SpecialIndicator(t) is defined in `indicator_fns`. """ function clever_covariate_and_weights( - Ψ::CMCompositeEstimand, + Ψ::StatisticalCMCompositeEstimand, Gs::Tuple{Vararg{ConditionalDistributionEstimate}}, dataset; ps_lowerbound=1e-8, diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index b58a6e87..8992ecc7 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -4,7 +4,7 @@ """ Defines relevant factors that need to be estimated in order to estimate any -Counterfactual Mean composite estimand (see `CMCompositeEstimand`). +Counterfactual Mean composite estimand (see `StatisticalCMCompositeEstimand`). """ struct CMRelevantFactors <: Estimand outcome_mean::ConditionalDistribution @@ -48,7 +48,7 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS outcome::Symbol treatment_values::NamedTuple - function $(statistical_estimand)(outcome, treatment_values) + function $(causal_estimand)(outcome, treatment_values) outcome = Symbol(outcome) treatment_variables = Tuple(keys(treatment_values)) return new(outcome, treatment_values) @@ -77,16 +77,18 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS $(statistical_estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) $(estimand)(outcome, treatment_values, treatment_confounders::Nothing, outcome_extra_covariates) = - $(statistical_estimand)(outcome, treatment_values) + $(causal_estimand)(outcome, treatment_values) - $(estimand)(;outcome, treatment_values, treatment_confounders, outcome_extra_covariates=()) = + $(estimand)(;outcome, treatment_values, treatment_confounders=nothing, outcome_extra_covariates=()) = $(estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) end eval(ex) end -CMCompositeEstimand = Union{(eval(Symbol(:Statistical, x)) for x in keys(ESTIMANDS_DOCS))...} +CausalCMCompositeEstimands = Union{(eval(Symbol(:Causal, x)) for x in keys(ESTIMANDS_DOCS))...} + +StatisticalCMCompositeEstimand = Union{(eval(Symbol(:Statistical, x)) for x in keys(ESTIMANDS_DOCS))...} indicator_fns(Ψ::StatisticalCM) = Dict(values(Ψ.treatment_values) => 1.) @@ -111,13 +113,13 @@ function indicator_fns(Ψ::StatisticalIATE) return Dict(key_vals...) end -function get_relevant_factors(Ψ::CMCompositeEstimand) +function get_relevant_factors(Ψ::StatisticalCMCompositeEstimand) outcome_model = ExpectedValue(Ψ.outcome, Tuple(union(Ψ.outcome_extra_covariates, keys(Ψ.treatment_confounders), (Ψ.treatment_confounders)...))) treatment_factors = Tuple(ConditionalDistribution(T, Ψ.treatment_confounders[T]) for T in treatments(Ψ)) return CMRelevantFactors(outcome_model, treatment_factors) end -function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: CMCompositeEstimand +function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: StatisticalCMCompositeEstimand param_string = string( Base.typename(T).wrapper, "\n-----", diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index d198ccbc..19177d7a 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -111,7 +111,7 @@ end TMLEE(models; resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) = TMLEE(models, resampling, ps_lowerbound, weighted, tol) -function (tmle::TMLEE)(Ψ::CMCompositeEstimand, dataset; cache=Dict(), verbosity=1) +function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors @@ -154,7 +154,7 @@ end OSE(models; resampling=nothing, ps_lowerbound=1e-8) = OSE(models, resampling, ps_lowerbound) -function (estimator::OSE)(Ψ::CMCompositeEstimand, dataset; cache=Dict(), verbosity=1) +function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors @@ -187,7 +187,7 @@ mutable struct NAIVE <: Estimator model::MLJBase.Supervised end -function (estimator::NAIVE)(Ψ::CMCompositeEstimand, dataset; cache=Dict(), verbosity=1) +function (estimator::NAIVE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index 176f2917..63613064 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -1,5 +1,5 @@ mutable struct Fluctuation <: MLJBase.Supervised - Ψ::CMCompositeEstimand + Ψ::StatisticalCMCompositeEstimand initial_factors::TMLE.MLCMRelevantFactors tol::Union{Nothing, Float64} ps_lowerbound::Float64 diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index 89c2009c..ff715039 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -1,5 +1,5 @@ """ - counterfactual_aggregate(Ψ::CMCompositeEstimand, Q::Machine) + counterfactual_aggregate(Ψ::StatisticalCMCompositeEstimand, Q::Machine) This is a counterfactual aggregate that depends on the parameter of interest. @@ -8,7 +8,7 @@ For the ATE with binary treatment, confounded by W it is equal to: ``ctf_agg = Q(1, W) - Q(0, W)`` """ -function counterfactual_aggregate(Ψ::CMCompositeEstimand, Q, dataset) +function counterfactual_aggregate(Ψ::StatisticalCMCompositeEstimand, Q, dataset) X = selectcols(dataset, Q.estimand.parents) Ttemplate = selectcols(X, treatments(Ψ)) n = nrows(Ttemplate) @@ -44,7 +44,7 @@ compute_estimate(ctf_aggregate, train_validation_indices) = This part of the gradient is evaluated on the original dataset. All quantities have been precomputed and cached. """ -function ∇YX(Ψ::CMCompositeEstimand, Q, G, dataset; ps_lowerbound=1e-8) +function ∇YX(Ψ::StatisticalCMCompositeEstimand, Q, G, dataset; ps_lowerbound=1e-8) # Maybe can cache some results (H and E[Y|X]) to improve perf here H, weights = clever_covariate_and_weights(Ψ, G, dataset; ps_lowerbound=ps_lowerbound) y = float(Tables.getcolumn(dataset, Q.estimand.outcome)) @@ -53,7 +53,7 @@ function ∇YX(Ψ::CMCompositeEstimand, Q, G, dataset; ps_lowerbound=1e-8) end -function gradient_and_estimate(Ψ::CMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8) +function gradient_and_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8) Q = factors.outcome_mean G = factors.propensity_score ctf_agg = TMLE.counterfactual_aggregate(Ψ, Q, dataset) diff --git a/src/scm.jl b/src/scm.jl index 6055f4c9..4487ea53 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -2,6 +2,11 @@ ### Structural Equation ### ##################################################################### +""" +A SCM is simply a MetaGraph over a Directed Acyclic Graph with additional methods. +""" +const SCM = MetaGraph + function SCM(equations...) scm = MetaGraph( SimpleDiGraph(); @@ -9,13 +14,18 @@ function SCM(equations...) vertex_data_type=Nothing, edge_data_type=Nothing ) - for (outcome, parents) in equations - add_equation!(scm, outcome, parents) - end + add_equations!(scm, equations...) return scm end -function add_equation!(scm::MetaGraph, outcome, parents) +function add_equations!(scm::SCM, equations...) + for outcome_parents_pair in equations + add_equation!(scm, outcome_parents_pair) + end +end + +function add_equation!(scm::SCM, outcome_parents_pair) + outcome, parents = outcome_parents_pair outcome_symbol = Symbol(outcome) add_vertex!(scm, outcome_symbol) for parent in parents @@ -25,53 +35,26 @@ function add_equation!(scm::MetaGraph, outcome, parents) end end +function parents(scm, label) + code, _ = scm.vertex_properties[label] + return [scm.vertex_labels[parent_code] for parent_code in scm.graph.badjlist[code]] +end -get_outcome_extra_covariates(outcome_parents, treatment_and_confounders, vertex_labels) = - [vertex_labels[p] for p ∈ outcome_parents if p ∉ treatment_and_confounders] - -##################################################################### -### Identification Methods ### -##################################################################### - -abstract type AdjustmentMethod end +""" +A plate Structural Causal Model where: -struct BackdoorAdjustment <: AdjustmentMethod - outcome_extra::Bool -end +- For all outcomes: oᵢ = fᵢ(treatments, confounders, outcome_extra_covariates) +- For all treatments: tⱼ = fⱼ(confounders) -function identify(method::BackdoorAdjustment, estimand::T, scm::MetaGraph) where T<:Union{Cau} - # Treatment confounders - treatment_names = keys(estimand.treatment) - treatment_codes = [code_for(scm, treatment) for treatment ∈ treatment_names] - confounders_codes = scm.graph.badjlist[treatment_codes] - treatment_confounders = NamedTuple{treatment_names}( - [[scm.vertex_labels[w] for w in confounders_codes[i]] - for i in eachindex(confounders_codes)] - ) - # Extra covariates - outcome_extra_covariates = nothing - if method.outcome_extra_covariates - outcome_parents_codes = scm.graph.badjlist[code_for(scm, estimand.outcome)] - treatment_and_confounders_codes = Set(vcat(treatment_codes, confounders_codes...)) - outcome_extra_covariates = get_outcome_extra_covariates( - outcome_parents_codes, - treatment_and_confounders_codes, - scm.vertex_labels - ) - end +# Example - return T( - outcome=estimand.outcome, - treatment_values = estimand.treatment_values, - treatment_confounders = treatment_confounders, - outcome_extra_covariates = outcome_extra_covariates - ) +StaticSCM([:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃]; outcome_extra_covariates=[:C]) +""" +function StaticSCM(outcomes, treatments, confounders; outcome_extra_covariates=()) + outcome_equations = (outcome => unique(vcat(treatments, confounders, outcome_extra_covariates)) for outcome in outcomes) + treatment_equations = (treatment => unique(confounders) for treatment in treatments) + return SCM(outcome_equations..., treatment_equations...) end -##################################################################### -### Causal Estimand ### -##################################################################### -struct CausalATE - outcome::Symbol - treatment_values::NamedTuple -end \ No newline at end of file +StaticSCM(;outcomes, treatments, confounders, outcome_extra_covariates=()) = + StaticSCM(outcomes, treatments, confounders; outcome_extra_covariates=outcome_extra_covariates) diff --git a/test/adjustment.jl b/test/adjustment.jl deleted file mode 100644 index afe51e7a..00000000 --- a/test/adjustment.jl +++ /dev/null @@ -1,28 +0,0 @@ -module TestAdjustment - -using Test -using TMLE - -@testset "Test BackdoorAdjustment" begin - scm = SCM( - SE(:Y, [:I, :V, :C]), - SE(:V, [:T, :K]), - SE(:T, [:W, :X]), - SE(:I, [:W, :G]), - ) - Ψ = CM(scm, treatment=(T=1,), outcome=:Y) - adjustment_method = BackdoorAdjustment() - @test TMLE.get_models_input_variables(adjustment_method, Ψ) == ( - T = [:W, :X], - Y = [:T, :W, :X] - ) - adjustment_method = BackdoorAdjustment(outcome_extra=[:C]) - @test TMLE.get_models_input_variables(adjustment_method, Ψ) == ( - T = [:W, :X], - Y = [:T, :W, :X, :C] - ) -end - -end - -true \ No newline at end of file diff --git a/test/counterfactual_mean_based/adjustment.jl b/test/counterfactual_mean_based/adjustment.jl new file mode 100644 index 00000000..688cbd29 --- /dev/null +++ b/test/counterfactual_mean_based/adjustment.jl @@ -0,0 +1,35 @@ +module TestAdjustment + +using Test +using TMLE + +@testset "Test BackdoorAdjustment" begin + scm = StaticSCM( + outcomes = [:Y₁, :Y₂], + treatments = [:T₁, :T₂], + confounders = [:W₁, :W₂], + outcome_extra_covariates = [:C] + ) + causal_estimands = [ + CM(outcome=:Y₁, treatment_values=(T₁=1,)), + ATE(outcome=:Y₁, treatment_values=(T₁=(case=1, control=0),)), + ATE(outcome=:Y₁, treatment_values=(T₁=(case=1, control=0), T₂=(case=1, control=0))), + IATE(outcome=:Y₁, treatment_values=(T₁=(case=1, control=0), T₂=(case=1, control=0))), + ] + method = BackdoorAdjustment(outcome_extra_covariates=[:C]) + statistical_estimands = [identify(method, estimand, scm) for estimand in causal_estimands] + + for (causal_estimand, statistical_estimand) in zip(causal_estimands, statistical_estimands) + @test statistical_estimand.outcome == causal_estimand.outcome + @test statistical_estimand.treatment_values == causal_estimand.treatment_values + @test statistical_estimand.outcome_extra_covariates == (:C,) + end + @test statistical_estimands[1].treatment_confounders == (T₁=(:W₁, :W₂),) + @test statistical_estimands[2].treatment_confounders == (T₁=(:W₁, :W₂),) + @test statistical_estimands[3].treatment_confounders == (T₁=(:W₁, :W₂), T₂=(:W₁, :W₂)) + @test statistical_estimands[4].treatment_confounders == (T₁=(:W₁, :W₂), T₂=(:W₁, :W₂)) +end + +end + +true \ No newline at end of file diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index 7aa0df7f..8d0ec93d 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -3,7 +3,7 @@ module TestEstimands using Test using TMLE -@testset "Test CMCompositeEstimand" begin +@testset "Test StatisticalCMCompositeEstimand" begin dataset = ( W = [1, 2, 3, 4, 5, 6, 7, 8], T₁ = ["A", "B", "A", "B", "A", "B", "A", "B"], @@ -81,7 +81,7 @@ using TMLE end @testset "Test structs are concrete types" begin - for type in Base.uniontypes(TMLE.CMCompositeEstimand) + for type in Base.uniontypes(TMLE.StatisticalCMCompositeEstimand) @test isconcretetype(type) end end diff --git a/test/estimands.jl b/test/estimands.jl index e239ea9a..f621a1c1 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -4,7 +4,7 @@ using TMLE using Test @testset "Test ConditionalDistribution" begin - distr = ConditionalDistribution("Y", ["C", 1, :A, ]) + distr = TMLE.ConditionalDistribution("Y", ["C", 1, :A, ]) @test distr.outcome === :Y @test distr.parents === (Symbol("1"), :A, :C) @test TMLE.variables(distr) == (:Y, Symbol("1"), :A, :C) @@ -12,16 +12,16 @@ end @testset "Test CMRelevantFactors" begin η = TMLE.CMRelevantFactors( - outcome_mean=ExpectedValue(:Y, [:T, :W]), - propensity_score=ConditionalDistribution(:T, [:W]) + outcome_mean=TMLE.ExpectedValue(:Y, [:T, :W]), + propensity_score=TMLE.ConditionalDistribution(:T, [:W]) ) @test TMLE.variables(η) == (:Y, :T, :W) η = TMLE.CMRelevantFactors( - outcome_mean=ExpectedValue(:Y, [:T, :W]), + outcome_mean=TMLE.ExpectedValue(:Y, [:T, :W]), propensity_score=( - ConditionalDistribution(:T₁, [:W₁]), - ConditionalDistribution(:T₂, [:W₂₁, :W₂₂]) + TMLE.ConditionalDistribution(:T₁, [:W₁]), + TMLE.ConditionalDistribution(:T₂, [:W₂₁, :W₂₂]) ) ) @test TMLE.variables(η) == (:Y, :T, :W, :T₁, :W₁, :T₂, :W₂₁, :W₂₂) diff --git a/test/estimators_and_estimates.jl b/test/estimators_and_estimates.jl index 94bf2f65..65eb6acb 100644 --- a/test/estimators_and_estimates.jl +++ b/test/estimators_and_estimates.jl @@ -12,7 +12,7 @@ verbosity = 1 n = 100 X, y = make_moons(n) dataset = DataFrame(Y=y, X₁=X.x1, X₂=X.x2) -estimand = ConditionalDistribution(:Y, [:X₁, :X₂]) +estimand = TMLE.ConditionalDistribution(:Y, [:X₁, :X₂]) fit_log = string("Estimating: ", TMLE.string_repr(estimand)) reuse_log = string("Reusing estimate for: ", TMLE.string_repr(estimand)) diff --git a/test/runtests.jl b/test/runtests.jl index 02c23051..63439c71 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,9 +7,8 @@ using Test @test include("missing_management.jl") @test include("composition.jl") @test include("treatment_transformer.jl") - # @test include("scm.jl") - # @test include("adjustment.jl") - + @test include("scm.jl") + @test include("counterfactual_mean_based/estimands.jl") @test include("counterfactual_mean_based/clever_covariate.jl") @test include("counterfactual_mean_based/gradient.jl") @@ -19,5 +18,5 @@ using Test @test include("counterfactual_mean_based/double_robustness_ate.jl") @test include("counterfactual_mean_based/double_robustness_iate.jl") @test include("counterfactual_mean_based/3points_interactions.jl") - + @test include("counterfactual_mean_based/adjustment.jl") end \ No newline at end of file diff --git a/test/scm.jl b/test/scm.jl index 1f4a4ba5..34c08bfd 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -2,96 +2,42 @@ module TestSCM using Test using TMLE -using MLJGLMInterface -using CategoricalArrays -using Random -using StableRNGs -using MLJBase - -@testset "Test SE" begin - # Incorrect Structural Equations - @test_throws TMLE.SelfReferringEquationError(:Y) SE(:Y, [:Y]) - # Correct Structural Equations - # - outcome is a Symbol - # - parents is a Set of Symbols - eq = SE("Y", [:T, "W"]) - @test TMLE.outcome(eq) == :Y - @test TMLE.parents(eq) == Set([:T, :W]) - @test TMLE.string_repr(eq) == "Y = f(T, W)" -end @testset "Test SCM" begin + # An empty SCM scm = SCM() - # setequation! - Yeq = SE(:Y, [:T, :W, :C]) - setequation!(scm, Yeq) - Teq = SE(:T, [:W]) - setequation!(scm, Teq) - # getequation - @test getequation(scm, :Y) == Yeq - @test getequation(scm, :T) == Teq - # set and get conditional_distribution with implied natural parents - cond_dist = set_conditional_distribution!(scm, :T, LinearBinaryClassifier()) - @test get_conditional_distribution(scm, :T) == cond_dist - # set and get conditional_distribution with provided parents - cond_dist = set_conditional_distribution!(scm, :Y, LinearBinaryClassifier()) - @test_throws KeyError get_conditional_distribution(scm, :Y, Set([:W, :T])) - cond_dist = TMLE.get_or_set_conditional_distribution_from_natural!(scm, :Y, Set([:W, :T])) - @test cond_dist.model == get_conditional_distribution(scm, :Y).model - # string representation - @test TMLE.string_repr(scm) == "Structural Causal Model:\n-----------------------\nT = f₁(W)\nY = f₂(T, W, C)\n" -end - -@testset "Test StaticConfoundedModel" begin - # With covariates - scm = StaticConfoundedModel(:Y, :T, [:W₁, :W₂], covariates=[:C₁]) - @test TMLE.parents(scm, :Y) == Set([:T, :W₁, :W₂, :C₁]) - @test TMLE.parents(scm, :T) == Set([:W₁, :W₂]) - @test get_conditional_distribution(scm, :T).model isa LinearBinaryClassifier - @test get_conditional_distribution(scm, :Y).model isa ProbabilisticPipeline - - # Without covariates - scm = StaticConfoundedModel(:Y, :T, :W₁, outcome_model=LinearBinaryClassifier(), treatment_model=LinearBinaryClassifier(fit_intercept=false)) - @test TMLE.parents(scm, :Y) == Set([:T, :W₁]) - @test TMLE.parents(scm, :T) == Set([:W₁]) - @test get_conditional_distribution(scm, :Y).model isa LinearBinaryClassifier - @test get_conditional_distribution(scm, :T).model.fit_intercept === false - - # With 1 covariate - scm = StaticConfoundedModel(:Y, :T, :W₁, covariates=:C₁) - @test TMLE.parents(scm, :Y) == Set([:T, :W₁, :C₁]) - @test TMLE.parents(scm, :T) == Set([:W₁]) - - # With multiple outcomes and treatments - scm = StaticConfoundedModel( - [:Y₁, :Y₂, :Y₃], - [:T₁, :T₂], - [:W₁, :W₂, :W₃], - covariates=[:C] - ) - - Yparents = Set([:T₁, :T₂, :W₁, :W₂, :W₃, :C]) - for Y in [:Y₁, :Y₂, :Y₃] - @test parents(scm, Y) == Yparents - @test get_conditional_distribution(scm, Y, Yparents).model isa ProbabilisticPipeline - end - - Tparents = Set([:W₁, :W₂, :W₃]) - for T in [:T₁, :T₂] - @test parents(scm, T) == Tparents - @test get_conditional_distribution(scm, T, Tparents).model isa LinearBinaryClassifier - end + add_equations!( + scm, + :Y => [:T₁, :T₂, :W₁, :W₂, :C], + :T₁ => [:W₁] + ) + add_equation!(scm, :T₂ => [:W₂]) + @test parents(scm, :T₁) == [:W₁] + @test parents(scm, :T₂) == [:W₂] + @test parents(scm, :Y) == [:T₁, :T₂, :W₁, :W₂, :C] + @test parents(scm, :C) == [] + # Another SCM + scm = SCM( + :Y₁ => [:T₁, :W₁], + :T₁ => [:W₁], + :Y₂ => [:T₁, :T₂, :W₁, :W₂, :C], + :T₂ => [:W₂] + ) + @test parents(scm, :Y₁) == [:T₁, :W₁] + @test parents(scm, :T₁) == [:W₁] + @test parents(scm, :Y₂) == [:T₁, :W₁, :T₂, :W₂, :C] + @test parents(scm, :T₂) == [:W₂] end -@testset "Test is_upstream" begin - scm = SCM( - SE(:T₁, [:W₁]), - SE(:T₂, [:W₂]), - SE(:Y, [:T₁, :T₂]), +@testset "Test StaticSCM" begin + scm = StaticSCM( + outcomes = [:Y₁, :Y₂], + treatments = [:T₁, :T₂, :T₃], + confounders = [:W₁, :W₂], + outcome_extra_covariates = [:C] ) - @test TMLE.is_upstream(:T₁, :Y, scm) === true - @test TMLE.is_upstream(:W₁, :Y, scm) === true - @test TMLE.is_upstream(:Y, :T₁, scm) === false + @test parents(scm, :Y₁) == parents(scm, :Y₂) == [:T₁, :T₂, :T₃, :W₁, :W₂, :C] + @test parents(scm, :T₁) == parents(scm, :T₂) == parents(scm, :T₃) == [:W₁, :W₂] end end From 26b84c9e43101dde8308876c5cde1995af450c9d Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 25 Sep 2023 15:49:18 +0100 Subject: [PATCH 086/151] add compat entries --- Project.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index fe015a7b..17953fd0 100644 --- a/Project.toml +++ b/Project.toml @@ -8,11 +8,13 @@ AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" +Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c" MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377" Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" @@ -40,4 +42,6 @@ TableOperations = "1.2" Tables = "1.6" YAML = "0.4" Zygote = "0.6" +Graphs = "1.8" +MetaGraphsNext = "0.6" julia = "1.6, 1.7, 1" From 2248e4a875ca0e8838d6c168b0eca15c128bcd71 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 25 Sep 2023 16:23:10 +0100 Subject: [PATCH 087/151] remove optimize ordering as the cache allows storage of multiple estimands now --- src/estimands.jl | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/estimands.jl b/src/estimands.jl index d486ef03..2bb4e05c 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -68,21 +68,6 @@ Function used to sort estimands for optimal estimation ordering. """ key(estimand::Estimand) = estimand -""" - optimize_ordering!(estimands::Vector{<:Estimand}) - -Optimizes the order of the `estimands` to maximize reuse of -fitted equations in the associated SCM. -""" -optimize_ordering!(estimands::Vector{<:Estimand}) = sort!(estimands, by=estimand_key) - -""" - optimize_ordering(estimands::Vector{<:Estimand}) - -See [`optimize_ordering!`](@ref) -""" -optimize_ordering(estimands::Vector{<:Estimand}) = sort(estimands, by=estimand_key) - ##################################################################### ### Conditional Distribution ### ##################################################################### From 44ea3ab4d434b875773c6f60f6f8786d106dc967 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 25 Sep 2023 16:25:50 +0100 Subject: [PATCH 088/151] only missing management to check --- README.md | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/README.md b/README.md index b73ba824..d9f10138 100644 --- a/README.md +++ b/README.md @@ -11,11 +11,4 @@ This package enables the estimation of various Causal Inference related estimand ## Notes to self -- For multiple TMLE steps, convergence when the mean of the IC is below: σ/n where σ is the standard deviation of the IC. -- Try make selectcols preserve input type (see MLJModel has a facility) -- Watch for discrepancy in train_validation_indices with missing data... -- Check that for IATEs simulations, the true effect size is actually an additive interaction and not a multiplicative one. -- clean imports in test files -- clean export list -- look into dataset management and missing values -- See about ordering of parameters \ No newline at end of file +- look into dataset management and missing values, Watch for discrepancy in train_validation_indices with missing data... \ No newline at end of file From 414ce754fb9917e59e33cf68684b057bca4d0e58 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 26 Sep 2023 12:19:04 +0100 Subject: [PATCH 089/151] add estimand to TMLE/OSE estimates --- src/counterfactual_mean_based/estimates.jl | 4 ++-- src/counterfactual_mean_based/estimators.jl | 4 ++-- test/composition.jl | 9 +++++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/counterfactual_mean_based/estimates.jl b/src/counterfactual_mean_based/estimates.jl index c759a96c..55002d35 100644 --- a/src/counterfactual_mean_based/estimates.jl +++ b/src/counterfactual_mean_based/estimates.jl @@ -19,18 +19,18 @@ string_repr(estimate::MLCMRelevantFactors) = string( join((string_repr(f) for f in estimate.propensity_score), "\n- ") ) - ##################################################################### ### One Dimensional Estimates ### ##################################################################### - struct TMLEstimate{T<:AbstractFloat} <: Estimate + estimand::StatisticalCMCompositeEstimand Ψ̂::T IC::Vector{T} end struct OSEstimate{T<:AbstractFloat} <: Estimate + estimand::StatisticalCMCompositeEstimand Ψ̂::T IC::Vector{T} end diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 19177d7a..c1e00476 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -138,7 +138,7 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict() IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." # update!(cache, relevant_factors, targeted_factors_estimate) - return TMLEstimate(Ψ̂, IC), cache + return TMLEstimate(Ψ, Ψ̂, IC), cache end ##################################################################### @@ -176,7 +176,7 @@ function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dic # Gradient and estimate IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." - return OSEstimate(Ψ̂ + mean(IC), IC), cache + return OSEstimate(Ψ, Ψ̂ + mean(IC), IC), cache end ##################################################################### diff --git a/test/composition.jl b/test/composition.jl index 31b4dd92..25acb08f 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -22,10 +22,15 @@ function make_dataset(;n=100) end @testset "Test cov" begin + Ψ = CM( + outcome = :Y, + treatment_values = (T=1,), + treatment_confounders = (T=[:W],) + ) n = 10 X = rand(n, 2) - ER₁ = TMLE.TMLEstimate(1., X[:, 1]) - ER₂ = TMLE.TMLEstimate(0., X[:, 2]) + ER₁ = TMLE.TMLEstimate(Ψ, 1., X[:, 1]) + ER₂ = TMLE.TMLEstimate(Ψ, 0., X[:, 2]) Σ = cov(ER₁, ER₂) @test size(Σ) == (2, 2) @test Σ == cov(X) From d9b9de8516d86b4b84b975c51cdeae552875bc41 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 26 Sep 2023 16:11:34 +0100 Subject: [PATCH 090/151] add some tests for missing management --- test/missing_management.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/missing_management.jl b/test/missing_management.jl index 7351d2f4..b601aa8a 100644 --- a/test/missing_management.jl +++ b/test/missing_management.jl @@ -24,6 +24,15 @@ function dataset_with_missing_and_ordered_treatment(;n=1000) end @testset "Test nomissing" begin + # If there is no missing data, this should be a no-op + dataset = DataFrame(rand(10, 3), :auto) + @test TMLE.nomissing(dataset) === dataset + filtered = TMLE.nomissing(dataset, [:x1, :x2]) + # The only cost to this operation is the creation of a new table but the columns are indistinguishable + @test filtered.x1 === dataset.x1 + @test filtered.x2 === dataset.x2 + + # Now with missing values dataset = dataset_with_missing_and_ordered_treatment(;n=100) # filter missing rows based on W column filtered = TMLE.nomissing(dataset, [:W]) From 35349f338a706e088166f01215da52bac307819f Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 25 Oct 2023 17:56:02 +0100 Subject: [PATCH 091/151] add some docstrings --- src/counterfactual_mean_based/estimators.jl | 54 ++++++++++++++++++++- src/estimators.jl | 2 +- 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index c1e00476..58741989 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -100,14 +100,41 @@ end ##################################################################### ### TMLE ### ##################################################################### + mutable struct TMLEE <: Estimator models::NamedTuple resampling::Union{Nothing, ResamplingStrategy} - ps_lowerbound::Float64 + ps_lowerbound::Union{Float64, Nothing} weighted::Bool tol::Union{Float64, Nothing} end +""" + TMLEE(models; resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) + +Defines a TMLE estimator using the specified models for estimation of the nuisance parameters. The estimator is a +function that can be applied to estimate estimands for a dataset. + +# Arguments + +- models: A NamedTuple{variables}(models) where the `variables` are the outcome variables modeled by the `models`. +- resampling: Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla TMLE while +any valid `MLJ.ResamplingStrategy` will result in CV-TMLE. +- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will +result in a data adaptive definition as described in [here](https://pubmed.ncbi.nlm.nih.gov/35512316/). +- weighted: Whether the fluctuation model is a classig GLM or a weighted version. The weighted fluctuation has +been show to be more robust to positivity violation in practice. +- tol: This is not used at the moment. + +# Example + +```julia +using MLJLinearModels +models = (Y = LinearRegressor(), T = LogisticClassifier()) +tmle = TMLEE(models) +Ψ̂ₙ, cache = tmle(Ψ, dataset) +``` +""" TMLEE(models; resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) = TMLEE(models, resampling, ps_lowerbound, weighted, tol) @@ -148,9 +175,32 @@ end mutable struct OSE <: Estimator models::NamedTuple resampling::Union{Nothing, ResamplingStrategy} - ps_lowerbound::Float64 + ps_lowerbound::Union{Float64, Nothing} end +""" + OSE(models; resampling=nothing, ps_lowerbound=1e-8) + +Defines a One Step Estimator using the specified models for estimation of the nuisance parameters. The estimator is a +function that can be applied to estimate estimands for a dataset. + +# Arguments + +- models: A NamedTuple{variables}(models) where the `variables` are the outcome variables modeled by the `models`. +- resampling: Outer resampling strategy. Setting it to `nothing` (default) falls back to vanilla estimation while +any valid `MLJ.ResamplingStrategy` will result in CV-OSE. +- ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will +result in a data adaptive definition as described in [here](https://pubmed.ncbi.nlm.nih.gov/35512316/). + +# Example + +```julia +using MLJLinearModels +models = (Y = LinearRegressor(), T = LogisticClassifier()) +ose = OSE(models) +Ψ̂ₙ, cache = ose(Ψ, dataset) +``` +""" OSE(models; resampling=nothing, ps_lowerbound=1e-8) = OSE(models, resampling, ps_lowerbound) diff --git a/src/estimators.jl b/src/estimators.jl index 7c85bea7..02e0cf67 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -60,7 +60,7 @@ function (estimator::SampleSplitMLConditionalDistributionEstimator)(estimand, da relevant_dataset = TMLE.selectcols(dataset, TMLE.variables(estimand)) nfolds = size(estimator.train_validation_indices, 1) machines = Vector{Machine}(undef, nfolds) - # Fit Conditional DIstribution on each training split using MLJ + # Fit Conditional Distribution on each training split using MLJ for (index, (train_indices, _)) in enumerate(estimator.train_validation_indices) train_dataset = selectrows(relevant_dataset, train_indices) Xtrain = selectcols(train_dataset, estimand.parents) From 3007230a60f8cbfb2fb5df614a7e1ba659f68e2a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 26 Oct 2023 11:17:13 +0100 Subject: [PATCH 092/151] update docs --- docs/src/index.md | 46 ++---- docs/src/user_guide/adjustment.md | 14 -- docs/src/user_guide/estimands.md | 94 ++++++++++-- docs/src/user_guide/estimation.md | 161 ++++++++++++-------- docs/src/user_guide/misc.md | 14 +- docs/src/user_guide/scm.md | 67 ++------ docs/src/walk_through.md | 17 +-- src/TMLE.jl | 5 +- src/counterfactual_mean_based/adjustment.jl | 4 + src/counterfactual_mean_based/estimators.jl | 17 ++- src/utils.jl | 9 ++ 11 files changed, 248 insertions(+), 200 deletions(-) delete mode 100644 docs/src/user_guide/adjustment.md diff --git a/docs/src/index.md b/docs/src/index.md index f7a675fa..fc8c4024 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -20,9 +20,7 @@ Pkg> add TMLE To run an estimation procedure, we need 3 ingredients: -1. A dataset -2. A Structural Causal Model that describes the relationship between the variables. -3. An estimand of interest +1. A dataset: here a simulation dataset. For illustration, assume we know the actual data generating process is as follows: @@ -54,51 +52,31 @@ dataset = (Y=Y, T=categorical(T), W=W) nothing # hide ``` -### Two lines TMLE +2. A quantity of interest: here the Average Treatment Effect (ATE). -Estimating the Average Treatment Effect can of ``T`` on ``Y`` can be as simple as: - -```@example quick-start -Ψ = ATE(outcome=:Y, treatment=(T=(case=true, control = false),), confounders=:W) -result, _ = tmle!(Ψ, dataset) -result -``` - -### Two steps approach - -Let's first define the Structural Causal Model: - -```@example quick-start -scm = StaticConfoundedModel(:Y, :T, :W) -``` - -and second, define the Average Treatment Effect of the treatment ``T`` on the outcome ``Y``: $ATE_{T:0 \rightarrow 1}(Y)$: +The Average Treatment Effect of ``T`` on ``Y`` confounded by ``W`` is defined as: ```@example quick-start Ψ = ATE( - scm, - outcome = :Y, - treatment = (T=(case=true, control = false),), + outcome=:Y, + treatment_values=(T=(case=true, control = false),), + treatment_confounders=(T=[:W],) ) ``` -Note that in this example the ATE can be computed exactly and is given by: - -```math -ATE_{0 \rightarrow 1}(P_0) = \mathbb{E}[1 + 3 - W] - \mathbb{E}[1] = 3 - \mathbb{E}[W] = 2.5 -``` - -Running the `tmle` will produce two asymptotically linear estimators: the TMLE and the One Step Estimator. For each we can look at the associated estimate, confidence interval and p-value: +3. An estimator: here a Targeted Maximum Likelihood Estimator (TMLE). ```@example quick-start -result, _ = tmle!(Ψ, dataset) +models = (Y=with_encoder(LinearRegressor()), T = LogisticClassifier()) +tmle = TMLEE(models) +result, _ = tmle(Ψ, dataset, verbosity=0); result ``` -and be comforted to see that our estimators covers the ground truth! 🥳 +We are comforted to see that our estimator covers the ground truth! 🥳 ```@example quick-start using Test # hide -@test pvalue(OneSampleTTest(result.tmle, 2.5)) > 0.05 # hide +@test pvalue(OneSampleTTest(result, 2.5)) > 0.05 # hide nothing # hide ``` diff --git a/docs/src/user_guide/adjustment.md b/docs/src/user_guide/adjustment.md deleted file mode 100644 index 712d924e..00000000 --- a/docs/src/user_guide/adjustment.md +++ /dev/null @@ -1,14 +0,0 @@ -```@meta -CurrentModule = TMLE -``` -# Adjustment Methods - -In a `SCM`, each variable is determined by a set of parents and a statistical model describing the functional relationship between them. However, for the estimation of Causal Estimands the fitted models may not exactly correspond to the variable's equation in the `SCM`. Adjustment methods tell the estimation procedure which input variables should be incorporated in the statistical model fits. - -## Backdoor Adjustment - -At the moment we provide a single adjustment method, namely the Backdoor adjustment method. The adjustment set consists of all the treatment variable's parents. Additional covariates used to fit the outcome model can be provided via `outcome_extra`. - -```julia -BackdoorAdjustment(;outcome_extra=[:C]) -``` diff --git a/docs/src/user_guide/estimands.md b/docs/src/user_guide/estimands.md index 412c0cae..e294ab38 100644 --- a/docs/src/user_guide/estimands.md +++ b/docs/src/user_guide/estimands.md @@ -14,7 +14,15 @@ Mathematically speaking, denoting the estimand by ``\Psi``, the set of all proba At the moment, most of the work in this package has been focused on estimands that are composite functions of the interventional conditional mean which is easily identified via backdoor adjustment and for which the efficient influence function is well known. -In what follows, ``P`` is a probability distribution generating an outcome ``Y``, a random vector of "treatment" variables ``\textbf{T}`` and a random vector of "confounding" variables ``\textbf{W}``. +In what follows, ``P`` is a probability distribution generating an outcome ``Y``, a random vector of "treatment" variables ``\textbf{T}`` and a random vector of "confounding" variables ``\textbf{W}``. For the examples, we will assume two treatment variables ``T₁`` and ``T₂`` taking either values 0 or 1. The ``SCM`` is given by: + +```@example estimands +scm = StaticSCM( + [:Y], + [:T₁, :T₂], + [:W]; +) +``` ## The Interventional Counterfactual Mean (CM) @@ -36,10 +44,26 @@ CM_{\textbf{t}}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|do(\textbf{T}=\textbf{ - TMLE.jl Example -For a Structural Causal Model `scm`, an outcome `Y` and two treatments `T₁` and `T₂` taking values `t₁` and `t₂` respectively: +A causal estimand is given by: -```julia -Ψ = CM(scm, outcome=:Y, treatment=(T₁=t₁, T₂=t₂)) +```@example estimands +causalΨ = CM(outcome=:Y, treatment_values=(T₁=1, T₂=0)) +``` + +A corresponding statistical estimand can be identified via backdoor adjustment using the `scm`: + +```@example estimands +statisticalΨ = identify(Ψ, scm) +``` + +or defined directly: + +```@example estimands +statisticalΨ = CM( + outcome=:Y, + treatment_values=(T₁=1, T₂=0), + treatment_confounders=(T₁=[:W], T₂=[:W]) +) ``` ## The Average Treatment Effect @@ -65,16 +89,35 @@ ATE_{\textbf{t}_1 \rightarrow \textbf{t}_2}(P) &= CM_{\textbf{t}_2}(P) - CM_{\te - TMLE.jl Example -For a Structural Causal Model `scm`, an outcome `Y` and two treatments differences `T₁`:`t₁₁ → t₁₂` and `T₂`:`t₂₁ → t₂₂`: +A causal estimand is given by: -```julia -Ψ = ATE(scm, outcome=:Y, treatment=(T₁=(case=t₁₂, control=t₁₁), T₂=(case=t₂₂, control=t₂₁))) +```@example estimands +causalΨ = ATE( + outcome=:Y, + treatment_values=( + T₁=(case=1, control=0), + T₂=(case=1, control=0) + ) +) ``` -Note that all treatments are not required to change, for instance the following where `T₁` is held fixed at `t₁₁` is also a valid `ATE`: +A corresponding statistical estimand can be identified via backdoor adjustment using the `scm`: -```julia -Ψ = ATE(scm, outcome=:Y, treatment=(T₁=(case=t₁₁, control=t₁₁), T₂=(case=t₂₂, control=t₂₁))) +```@example estimands +statisticalΨ = identify(causalΨ, scm) +``` + +or defined directly: + +```@example estimands +statisticalΨ = ATE( + outcome=:Y, + treatment_values=( + T₁=(case=1, control=0), + T₂=(case=1, control=0) + ), + treatment_confounders=(T₁=[:W], T₂=[:W]) +) ``` ## The Interaction Average Treatment Effect @@ -103,10 +146,35 @@ IATE_{0 \rightarrow 1, 0 \rightarrow 1}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[ - TMLE.jl Example -For a Structural Causal Model `scm`, an outcome `Y` and two treatments differences `T₁`:`t₁₁ → t₁₂` and `T₂`:`t₂₁ → t₂₂`: +A causal estimand is given by: -```julia -Ψ = IATE(scm, outcome=:Y, treatment=(T₁=(case=t₁₂, control=t₁₁), T₂=(case=t₂₂, control=t₂₁))) +```@example estimands +causalΨ = IATE( + outcome=:Y, + treatment_values=( + T₁=(case=1, control=0), + T₂=(case=1, control=0) + ) +) +``` + +A corresponding statistical estimand can be identified via backdoor adjustment using the `scm`: + +```@example estimands +statisticalΨ = identify(causalΨ, scm) +``` + +or defined directly: + +```@example estimands +statisticalΨ = IATE( + outcome=:Y, + treatment_values=( + T₁=(case=1, control=0), + T₂=(case=1, control=0) + ), + treatment_confounders=(T₁=[:W], T₂=[:W]) +) ``` ## Any function of the previous Estimands diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index 3013a810..01d5d00a 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -44,101 +44,118 @@ function make_dataset(;n=1000) end dataset = make_dataset(n=10000) scm = SCM( - SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], with_encoder(LinearRegressor())), - SE(:T₁, [:W₁₁, :W₁₂], LogisticClassifier()), - SE(:T₂, [:W₂₁, :W₂₂], LogisticClassifier()), + :Y => [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], + :T₁ => [:W₁₁, :W₁₂], + :T₂ => [:W₂₁, :W₂₂], ) ``` -Once a `SCM` and an estimand have been defined, we can proceed with Targeted Estimation. This is done via the `tmle` function. Drawing from the example dataset and `SCM` from the Walk Through section, we can estimate the ATE for `T₁`. +Once a statistical estimand has been defined, we can proceed with estimation. At the moment, we provide 3 main types of estimators: + +- Targeted Maximum Likelihood Estimator (`TMLEE`) +- One-Step Estimator (`OSE`) +- Naive Plugin Estimator (`NAIVE`) + +Drawing from the example dataset and `SCM` from the Walk Through section, we can estimate the ATE for `T₁`. Let's use TMLE: ```@example estimation -Ψ₁ = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false),)) -result₁, fluctuation_mach = tmle!(Ψ₁, dataset; - adjustment_method=BackdoorAdjustment([:C]), - verbosity=1, - force=false, - ps_lowerbound=1e-8, - weighted_fluctuation=false +Ψ₁ = ATE( + outcome=:Y, + treatment_values=(T₁=(case=true, control=false),), + treatment_confounders=(T₁=[:W₁₁, :W₁₂],), + outcome_extra_covariates=[:C] ) +models = ( + Y=with_encoder(LinearRegressor()), + T₁=LogisticClassifier(), + T₂=LogisticClassifier(), +) +tmle = TMLEE(models) +result₁, cache = tmle(Ψ₁, dataset); +result₁ nothing # hide ``` We see that both models corresponding to variables `Y` and `T₁` were fitted in the process but that the model for `T₂` was not because it was not necessary to estimate this estimand. -The `fluctuation_mach` corresponds to the fitted machine that was used to fluctuate the initial fit. For instance, we can see what is the value of ``\epsilon`` corresponding to the clever covariate. +The `cache` contains estimates for the nuisance functions that were necessary to estimate the ATE. For instance, we can see what is the value of ``\epsilon`` corresponding to the clever covariate. ```@example estimation -ϵ = fitted_params(fluctuation_mach).coef[1] +ϵ = last_fluctuation_epsilon(cache) ``` -The `result` corresponds to the estimation result and contains 3 main elements: +The `result₁` structure corresponds to the estimation result and should report 3 main elements: -- The `TMLEEstimate` than can be accessed via: `tmle(result)`. -- The `OSEstimate` than can be accessed via: `ose(result)`. -- The naive initial estimate. +- A point estimate. +- A 95% confidence interval. +- A p-value (Corresponding to the test that the estimand is different than 0). -Since both the TMLE and OSE are asymptotically linear estimators, standard T tests from [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) can be performed for each of them. +This is only summary statistics but since both the TMLE and OSE are asymptotically linear estimators, standard Z/T tests from [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) can be performed. ```@example estimation -tmle_test_result = OneSampleTTest(tmle(result)) +tmle_test_result₁ = OneSampleTTest(result₁) ``` -We could now get an interest in the Average Treatment Effect of `T₂`: +We could now get an interest in the Average Treatment Effect of `T₂` that we will estimate with an `OSE`: ```@example estimation -Ψ₂ = ATE(scm, outcome=:Y, treatment=(T₂=(case=true, control=false),)) -result₂, fluctuation_mach = tmle!(Ψ₂, dataset; - adjustment_method=BackdoorAdjustment([:C]), - verbosity=1, - force=false, - ps_lowerbound=1e-8, - weighted_fluctuation=false +Ψ₂ = ATE( + outcome=:Y, + treatment_values=(T₂=(case=true, control=false),), + treatment_confounders=(T₂=[:W₂₁, :W₂₂],), + outcome_extra_covariates=[:C] ) +ose = OSE(models) +result₂, cache = ose(Ψ₂, dataset;cache=cache); +result₂ nothing # hide ``` -The model for `T₂` was fitted in the process but so was the model for `Y` 🤔. This is because the `BackdoorAdjustment` method determined that the set of inputs for `Y` were different in both cases. +Again, required nuisance functions are fitted and stored in the cache. ## Reusing the SCM -Let's now see how the models can be reused with a new estimand, say the Total Average Treatment Effecto of both `T₁` and `T₂`. +Let's now see how the `cache` can be reused with a new estimand, say the Total Average Treatment Effect of both `T₁` and `T₂`. ```@example estimation -Ψ₃ = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false))) -result₃, fluctuation_mach = tmle!(Ψ₃, dataset; - adjustment_method=BackdoorAdjustment([:C]), - verbosity=1, - force=false, - ps_lowerbound=1e-8, - weighted_fluctuation=false +Ψ₃ = ATE( + outcome=:Y, + treatment_values=( + T₁=(case=true, control=false), + T₂=(case=true, control=false) + ), + treatment_confounders=( + T₁=[:W₁₁, :W₁₂], + T₂=[:W₂₁, :W₂₂], + ), + outcome_extra_covariates=[:C] ) +result₃, cache = tmle(Ψ₃, dataset; cache=cache); +result₃ nothing # hide ``` -This time only the statistical model for `Y` is fitted again while reusing the models for `T₁` and `T₂`. Finally, let's see what happens if we estimate the `IATE` between `T₁` and `T₂`. +This time only the model for `Y` is fitted again while reusing the models for `T₁` and `T₂`. Finally, let's see what happens if we estimate the `IATE` between `T₁` and `T₂`. ```@example estimation -Ψ₄ = IATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false))) -result₄, fluctuation_mach = tmle!(Ψ₄, dataset; - adjustment_method=BackdoorAdjustment([:C]), - verbosity=1, - force=false, - ps_lowerbound=1e-8, - weighted_fluctuation=false +Ψ₄ = IATE( + outcome=:Y, + treatment_values=( + T₁=(case=true, control=false), + T₂=(case=true, control=false) + ), + treatment_confounders=( + T₁=[:W₁₁, :W₁₂], + T₂=[:W₂₁, :W₂₂], + ), + outcome_extra_covariates=[:C] ) +result₄, cache = tmle(Ψ₄, dataset; cache=cache); +result₄ nothing # hide ``` -All statistical models have been reused 😊! - -## Ordering the estimands - -Given a vector of estimands, a clever ordering can be obtained via the `optimize_ordering/optimize_ordering!` functions. - -```@example estimation -optimize_ordering([Ψ₃, Ψ₁, Ψ₂, Ψ₄]) == [Ψ₁, Ψ₃, Ψ₄, Ψ₂] -``` +All nuisance functions have been reused, only the fluctuation is fitted! ## Composing Estimands @@ -151,25 +168,37 @@ IATE_{T_1=0 \rightarrow 1, T_2=0 \rightarrow 1} = ATE_{T_1=0 \rightarrow 1, T_2= ``` ```@example estimation -first_ate = ATE(scm, outcome=:Y, treatment=(T₁=(case=true, control=false), T₂=(case=false, control=false))) -first_ate_result, _ = tmle!(first_ate, dataset) - -second_ate = ATE(scm, outcome=:Y, treatment=(T₁=(case=false, control=false), T₂=(case=true, control=false))) -second_ate_result, _ = tmle!(second_ate, dataset) +first_ate = ATE( + outcome=:Y, + treatment_values=( + T₁=(case=true, control=false), + T₂=(case=false, control=false)), + treatment_confounders=( + T₁=[:W₁₁, :W₁₂], + T₂=[:W₂₁, :W₂₂], + ), +) +first_ate_result, cache = tmle(first_ate, dataset, cache=cache, verbosity=0); + +second_ate = ATE( + outcome=:Y, + treatment_values=( + T₁=(case=false, control=false), + T₂=(case=true, control=false)), + treatment_confounders=( + T₁=[:W₁₁, :W₁₂], + T₂=[:W₂₁, :W₂₂], + ), + ) +second_ate_result, cache = tmle(second_ate, dataset, cache=cache, verbosity=0); composed_iate_result = compose( (x, y, z) -> x - y - z, - tmle(result₃), tmle(first_ate_result), tmle(second_ate_result) + result₃, first_ate_result, second_ate_result ) isapprox( - estimate(tmle(result₄)), + estimate(result₄), estimate(composed_iate_result), atol=0.1 ) ``` - -## Weighted Fluctuation - -It has been reported that, in settings close to positivity violation (some treatments' values are very rare) TMLE may be unstable. This has been shown to be stabilized by fitting a weighted fluctuation model instead and by slightly modifying the clever covariate to keep things mathematically sound. - -This is implemented in TMLE.jl and can be turned on by selecting `weighted_fluctuation=true` in the `tmle` function. diff --git a/docs/src/user_guide/misc.md b/docs/src/user_guide/misc.md index ed47dd5c..9676712d 100644 --- a/docs/src/user_guide/misc.md +++ b/docs/src/user_guide/misc.md @@ -1,10 +1,22 @@ # Miscellaneous +## Adjustment Methods + +An adjustment method is a function that transforms a causal estimand into a statistical estimand using an associated `SCM`. At the moment, the only available adjustment method is the backdoor adjustment. + +### Backdoor Adjustment + +The adjustment set consists of all the treatment variable's parents. Additional covariates used to fit the outcome model can be provided via `outcome_extra`. + +```julia +BackdoorAdjustment(;outcome_extra_covariates=[:C]) +``` + ## Treatment Transformer To account for the fact that treatment variables are categorical variables we provide a MLJ compliant transformer that will either: -- Retrieve the floating point representation of a treatment it it has a natural ordering +- Retrieve the floating point representation of a treatment if it has a natural ordering - One hot encode it otherwise Such transformer can be created with: diff --git a/docs/src/user_guide/scm.md b/docs/src/user_guide/scm.md index 59a15354..641e580e 100644 --- a/docs/src/user_guide/scm.md +++ b/docs/src/user_guide/scm.md @@ -4,7 +4,7 @@ CurrentModule = TMLE # Structural Causal Models -In TMLE.jl, everything starts from the definition of a Structural Causal Model (`SCM`). A `SCM` in a series of Structural Equations (`SE`) that describe the causal relationships between the random variables under study. The purpose of this package is not to infer the `SCM`, instead we assume it is provided by the user. There are multiple ways one can define a `SCM` that we now describe. +Even if you don't have to, it can be useful to define a Structural Causal Model (`SCM`) for your problem. A `SCM` is a directed acyclic graph that describes the causal relationships between the random variables under study. ## Incremental Construction @@ -18,31 +18,13 @@ scm = SCM() This model does not say anything about the random variables and is thus not really useful. Let's assume that we are interested in an outcome ``Y`` and that this outcome is determined by 8 other random variables. We can add this assumption to the model ```@example scm -push!(scm, SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C])) +add_equation!(scm, :Y => [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C]) ``` -At this point, we haven't made any assumption regarding the functional form of the relationship between ``Y`` and its parents. We can add a further assumption by setting a statistical model for ``Y``, suppose we know it is generated from a logistic model, we can make that explicit: - -```@example scm -using MLJLinearModels -setmodel!(scm.Y, LogisticClassifier()) -``` - ---- -ℹ️ **Note on Models** - -- TMLE.jl is based on the main machine-learning framework in Julia: [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/). As such, any model respecting the MLJ interface is a valid model in TMLE.jl. -- In real world scenarios, we usually don't know what is the true statistical model for each variable and want to keep it as large as possible. For this reason it is recommended to use Super-Learning which is implemented in MLJ by the [Stack](https://alan-turing-institute.github.io/MLJ.jl/dev/model_stacking/#Model-Stacking) and comes with theoretical properties. -- In the dataset, treatment variables are represented with [categorical data](https://alan-turing-institute.github.io/MLJ.jl/dev/working_with_categorical_data/). This means the models that depend on such variables will need to properly deal with them. For this purpose we provide a `TreatmentTransformer` which can easily be combined with any `model` in a [Pipelining](https://alan-turing-institute.github.io/MLJ.jl/dev/linear_pipelines/) flavour with `with_encoder(model)`. -- The `SCM` has no knowledge of the data and thus cannot verify that the assumed statistical model is compatible with the data. This is done at a later stage. - ---- - Let's now assume that we have a more complete knowledge of the problem and we also know how `T₁` and `T₂` depend on the rest of the variables in the system. ```@example scm -push!(scm, SE(:T₁, [:W₁₁, :W₁₂, :W], model=LogisticClassifier())) -push!(scm, SE(:T₂, [:W₂₁, :W₂₂, :W])) +add_equations!(scm, :T₁ => [:W₁₁, :W₁₂, :W], :T₂ => [:W₂₁, :W₂₂, :W]) ``` ## One Step Construction @@ -51,46 +33,23 @@ Instead of constructing the `SCM` incrementally, one can provide all the specifi ```@example scm scm = SCM( - SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C], with_encoder(LinearRegressor())), - SE(:T₁, [:W₁₁, :W₁₂, :W], model=LogisticClassifier()), - SE(:T₂, [:W₂₁, :W₂₂, :W]), + :Y => [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C], + :T₁ => [:W₁₁, :W₁₂, :W], + :T₂ => [:W₂₁, :W₂₂, :W], ) ``` -Noting that we have used the `with_encoder` function to reflect the fact that we know that `T₁` and `T₂` are categorical variables. - ## Classic Structural Causal Models -There are many cases where we are interested in estimating the causal effect of a single treatment variable on a single outcome. Because it is typically only necessary to adjust for backdoor variables in order to identify this causal effect, we provide the `StaticConfoundedModel` interface to build such `SCM`: - -```@example scm -scm = StaticConfoundedModel( - :Y, :T, [:W₁, :W₂]; - covariates=[:C], - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LogisticClassifier() -) -``` - -The optional `covariates` are variables that influence the outcome but are not confounding the treatment. - -This model can be extended to a plate-model with multiple treatments and multiple outcomes. In this case the set of confounders is assumed to confound all treatments which are in turn assumed to impact all outcomes. This can be defined as: +There are many cases where we are interested in estimating the causal effect of a some treatment variables on a some outcome variables. If all treatment variables share the same set of confounders, we can quickly define the associated `SCM` with the `StaticSCM` interface: ```@example scm -scm = StaticConfoundedModel( - [:Y₁, :Y₂], [:T₁, :T₂], [:W₁, :W₂]; - covariates=[:C], - outcome_model = with_encoder(LinearRegressor()), - treatment_model = LogisticClassifier() +scm = StaticSCM( + [:Y₁, :Y₂], + [:T₁, :T₂], + [:W₁, :W₂]; + outcome_extra_covariates=[:C], ) ``` -More classic `SCM` may be added in the future based on needs. - -## Fitting the SCM - -It is usually not necessary to fit an entire `SCM` in order to estimate causal estimands of interest. Instead only some components are required and will be automatically determined (see [Estimation](@ref)). However, if you like, you can fit all the equations for which statistical models have been provided against a dataset: - -```julia -fit!(scm, dataset) -``` +where `outcome_extra_covariates` is a set of extra variables that are causal of the outcomes but are not of direct interest in the study. diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index 9ae4f66c..acac696a 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -66,24 +66,15 @@ The modeling stage starts from the definition of a Structural Causal Model (`SCM ```@example walk-through scm = SCM( - SE(:Y, [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], with_encoder(LinearRegressor())), - SE(:T₁, [:W₁₁, :W₁₂], LogisticClassifier()), - SE(:T₂, [:W₂₁, :W₂₂], LogisticClassifier()), + :Y => [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], + :T₁ => [:W₁₁, :W₁₂], + :T₂ => [:W₂₁, :W₂₂] ) ``` ---- -**NOTE** - -- Each Structural Equation specifies a child node, its parents and the assumed relationship between them. Here we know the model class from which each variable has been generated but in practice this is usually not the case. Instead we recommend the use of Super Learning / Stacking (see TODO). -- Because the treatment variables are categorical, they need to be encoded to be digested by the downstream models, `with_encoder(model)` simply creates a [Pipeline](https://alan-turing-institute.github.io/MLJ.jl/dev/linear_pipelines/#Linear-Pipelines) equivalent to `TreatmentEncoder() |> model`. -- At this point, the `SCM` has no knowledge about the dataset and cannot verify that the models are compatible with the actual data. - ---- - ## The Estimands -From the previous causal model we can ask multiple causal questions which are each represented by a distinct estimand. The set of available estimands types can be listed as follow: +From the previous causal model we can ask multiple causal questions which are each represented by distinct estimands. The set of available estimands types can be listed as follow: ```@example walk-through AVAILABLE_ESTIMANDS diff --git a/src/TMLE.jl b/src/TMLE.jl index b6a80723..ac236d5e 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -34,9 +34,10 @@ export optimize_ordering, optimize_ordering! export TMLEE, OSE, NAIVE export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint export compose -export TreatmentTransformer, with_encoder +export TreatmentTransformer, with_encoder, encoder export BackdoorAdjustment, identify - +export last_fluctuation_epsilon + # ############################################################################# # INCLUDES # ############################################################################# diff --git a/src/counterfactual_mean_based/adjustment.jl b/src/counterfactual_mean_based/adjustment.jl index 79e74a3c..f9c53b6e 100644 --- a/src/counterfactual_mean_based/adjustment.jl +++ b/src/counterfactual_mean_based/adjustment.jl @@ -18,6 +18,10 @@ function statistical_type_from_causal_type(T) return eval(Symbol(new_typestring)) end +identify(causal_estimand::T, scm::MetaGraph; method=BackdoorAdjustment()::BackdoorAdjustment) where T<:CausalCMCompositeEstimands = + identify(method, causal_estimand, scm) + + function identify(method::BackdoorAdjustment, causal_estimand::T, scm::MetaGraph) where T<:CausalCMCompositeEstimands # Treatment confounders treatment_names = keys(causal_estimand.treatment_values) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 58741989..f032529a 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -142,7 +142,6 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict() # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors - verbosity >= 1 && @info "Fitting the required equations..." relevant_factors = TMLE.get_relevant_factors(Ψ) nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, tmle.resampling) @@ -208,7 +207,6 @@ function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dic # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors - verbosity >= 1 && @info "Fitting the required equations..." initial_factors = TMLE.get_relevant_factors(Ψ) nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(initial_factors)) initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, estimator.resampling) @@ -241,7 +239,6 @@ function (estimator::NAIVE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=D # Check the estimand against the dataset TMLE.check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors - verbosity >= 1 && @info "Fitting the required equations..." relevant_factors = TMLE.get_relevant_factors(Ψ) nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) outcome_mean_estimate = MLConditionalDistributionEstimator(estimator.model)( @@ -252,4 +249,18 @@ function (estimator::NAIVE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=D ) Ψ̂ = mean(counterfactual_aggregate(Ψ, outcome_mean_estimate, nomissing_dataset)) return Ψ̂, cache +end + +##################################################################### +### Causal Estimand Estimation ### +##################################################################### + + +function (estimator::Union{NAIVE, OSE, TMLEE})(causalΨ::CausalCMCompositeEstimands, dataset; + identification_method=BackdoorAdjustment(), + cache=Dict(), + verbosity=1 + ) + Ψ = identify(identification_method, causalΨ, scm) + return estimator(Ψ, dataset; cache=cache, verbosity=verbosity) end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index c03a8422..9f9807bd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -58,4 +58,13 @@ function counterfactualTreatment(vals, T) NamedTuple{Tnames}( [categorical(repeat([vals[i]], n), levels=levels(Tables.getcolumn(T, name)), ordered=isordered(Tables.getcolumn(T, name))) for (i, name) in enumerate(Tnames)]) +end + + +last_fluctuation(cache) = cache[:last_fluctuation] + +function last_fluctuation_epsilon(cache) + mach = TMLE.last_fluctuation(cache).outcome_mean.machine + fp = fitted_params(fitted_params(mach).fitresult.one_dimensional_path) + return fp.coef end \ No newline at end of file From 02d42b779d02cb1207eb5ebf4a052f65706fdb8f Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 30 Oct 2023 19:21:40 +0000 Subject: [PATCH 093/151] fix some more docs --- README.md | 6 +- docs/make.jl | 16 ++--- docs/src/walk_through.md | 74 +++++++++++---------- src/counterfactual_mean_based/adjustment.jl | 2 +- src/counterfactual_mean_based/estimands.jl | 4 +- src/counterfactual_mean_based/estimators.jl | 2 +- 6 files changed, 53 insertions(+), 51 deletions(-) diff --git a/README.md b/README.md index d9f10138..9fdb5f4d 100644 --- a/README.md +++ b/README.md @@ -7,8 +7,4 @@ This package enables the estimation of various Causal Inference related estimands using Targeted Minimum Loss-Based Estimation (TMLE). TMLE.jl is based on top of [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/), which means any MLJ compliant machine-learning model can be used here. -**New to TMLE.jl?** [Get started now](https://targene.github.io/TMLE.jl/stable/) - -## Notes to self - -- look into dataset management and missing values, Watch for discrepancy in train_validation_indices with missing data... \ No newline at end of file +**New to TMLE.jl?** [Get started now](https://targene.github.io/TMLE.jl/stable/) \ No newline at end of file diff --git a/docs/make.jl b/docs/make.jl index af708a93..5500fc44 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,11 +8,11 @@ DocMeta.setdocmeta!(TMLE, :DocTestSetup, :(using TMLE); recursive=true) @info "Building Literate pages..." examples_dir = joinpath(@__DIR__, "../examples") build_examples_dir = joinpath(@__DIR__, "src", "examples/") -for file in readdir(examples_dir) - if endswith(file, "jl") - Literate.markdown(joinpath(examples_dir, file), build_examples_dir;documenter=true) - end -end +# for file in readdir(examples_dir) +# if endswith(file, "jl") +# Literate.markdown(joinpath(examples_dir, file), build_examples_dir;documenter=true) +# end +# end @info "Running makedocs..." makedocs(; @@ -29,11 +29,11 @@ makedocs(; "Home" => "index.md", "Walk Through" => "walk_through.md", "User Guide" => [joinpath("user_guide", f) for f in - ("scm.md", "estimands.md", "estimation.md", "adjustment.md", "misc.md")], + ("scm.md", "estimands.md", "estimation.md", "misc.md")], "Examples" => [ # joinpath("examples", "introduction_to_targeted_learning.md"), - joinpath("examples", "super_learning.md"), - joinpath("examples", "double_robustness.md") + # joinpath("examples", "super_learning.md"), + # joinpath("examples", "double_robustness.md") ], "Resources" => "resources.md", "API Reference" => "api.md" diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index acac696a..fe03cdd7 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -72,38 +72,38 @@ scm = SCM( ) ``` -## The Estimands +## The Causal Estimands -From the previous causal model we can ask multiple causal questions which are each represented by distinct estimands. The set of available estimands types can be listed as follow: +From the previous causal model we can ask multiple causal questions which are each represented by distinct causal estimands. The set of available estimands types can be listed as follow: ```@example walk-through AVAILABLE_ESTIMANDS ``` -At the moment there are 3 main estimand types we can estimate in TMLE.jl, we provide below a few examples. +At the moment there are 3 main causal estimands in TMLE.jl, we provide below a few examples. -- The Interventional Counterfactual Mean: +- The Counterfactual Mean: ```@example walk-through cm = CM( - scm, - outcome=:Y, - treatment=(T₁=1,) - ) + outcome = :Y, + treatment_values = (T₁=true,) +) ``` - The Average Treatment Effect: ```@example walk-through total_ate = ATE( - scm, - outcome=:Y, - treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)) + outcome = :Y, + treatment_values = ( + T₁=(case=1, control=0), + T₂=(case=1, control=0) + ) ) marginal_ate_t1 = ATE( - scm, - outcome=:Y, - treatment=(T₁=(case=1, control=0),) + outcome = :Y, + treatment_values = (T₁=(case=1, control=0),) ) ``` @@ -111,52 +111,56 @@ marginal_ate_t1 = ATE( ```@example walk-through iate = IATE( - scm, - outcome=:Y, - treatment=(T₁=(case=1, control=0), T₂=(case=1, control=0)) + outcome = :Y, + treatment_values = ( + T₁=(case=1, control=0), + T₂=(case=1, control=0) + ) ) ``` -## Targeted Estimation +## Identification -Then each parameter can be estimated by calling the `tmle!` function. For example: +Identification is the process by which a Causal Estimand is turned into a Statistical Estimand, that is, a quantity we may estimate from data. This is done via the `identify` function which also takes in the ``SCM``: ```@example walk-through -result, _ = tmle!(cm, dataset) -result +statistical_iate = identify(iate, scm) ``` -The `result` contains 3 main elements: +Alternatively, you can also directly define the statistical parameters (see [Estimands](@ref)). -- The `TMLEEstimate` than can be accessed via: +## Estimation -```@example walk-through -tmle(result) -``` - -- The `OSEstimate` than can be accessed via: +Then each parameter can be estimated by building an estimator (which is simply a function) and calling it on data. For illustration, we will keep the models used simple linear models. We define a TMLE: ```@example walk-through -ose(result) +models = ( + Y = with_encoder(LinearRegressor()), + T₁ = LogisticClassifier(), + T₂ = LogisticClassifier() +) +tmle = TMLEE(models) ``` -- The naive initial estimate. +And estimate the `cm` causal parameter. Because we haven't identified the parameter, we need to provide the `scm` as well: ```@example walk-through -naive(result) +result, cache = tmle(cm, scm, dataset); +result ``` -The adjustment set is determined by the provided `adjustment_method` keyword. At the moment, only `BackdoorAdjustment` is available. However one can specify that extra covariates could be used to fit the outcome model. +We can estimate statistical parameters directly, let's use the One-Step estimator: ```@example walk-through -result, _ = tmle!(iate, dataset;adjustment_method=BackdoorAdjustment([:C])) +ose = OSE(models) +result, cache = ose(statistical_iate, dataset) result ``` ## Hypothesis Testing -Because the TMLE and OSE are asymptotically linear estimators, they asymptotically follow a Normal distribution. This means one can perform standard T tests of null hypothesis. TMLE.jl extends the method provided by the [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) package that can be used as follows. +Because the TMLE and OSE are asymptotically linear estimators, they asymptotically follow a Normal distribution. This means one can perform standard T/Z tests of null hypothesis. TMLE.jl extends the method provided by the [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) package that can be used as follows. ```@example walk-through -OneSampleTTest(tmle(result)) +OneSampleTTest(result) ``` diff --git a/src/counterfactual_mean_based/adjustment.jl b/src/counterfactual_mean_based/adjustment.jl index f9c53b6e..9fbca3b4 100644 --- a/src/counterfactual_mean_based/adjustment.jl +++ b/src/counterfactual_mean_based/adjustment.jl @@ -9,7 +9,7 @@ struct BackdoorAdjustment <: AdjustmentMethod BackdoorAdjustment(outcome_extra_covariates) = new(unique_sorted_tuple(outcome_extra_covariates)) end -BackdoorAdjustment(;outcome_extra_covariates=false) = +BackdoorAdjustment(;outcome_extra_covariates=()) = BackdoorAdjustment(outcome_extra_covariates) function statistical_type_from_causal_type(T) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 8992ecc7..761948a4 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -33,7 +33,7 @@ variables(estimand::CMRelevantFactors) = ### Functionals ### ##################################################################### -ESTIMANDS_DOCS = Dict( +const ESTIMANDS_DOCS = Dict( :CM => (formula="``CM(Y, T=t) = E[Y|do(T=t)]``",), :ATE => (formula="``ATE(Y, T, case, control) = E[Y|do(T=case)] - E[Y|do(T=control)``",), :IATE => (formula="``IATE = E[Y|do(T₁=1, T₂=1)] - E[Y|do(T₁=1, T₂=0)] - E[Y|do(T₁=0, T₂=1)] + E[Y|do(T₁=0, T₂=0)]``",) @@ -90,6 +90,8 @@ CausalCMCompositeEstimands = Union{(eval(Symbol(:Causal, x)) for x in keys(ESTIM StatisticalCMCompositeEstimand = Union{(eval(Symbol(:Statistical, x)) for x in keys(ESTIMANDS_DOCS))...} +const AVAILABLE_ESTIMANDS = [x[1] for x ∈ ESTIMANDS_DOCS] + indicator_fns(Ψ::StatisticalCM) = Dict(values(Ψ.treatment_values) => 1.) function indicator_fns(Ψ::StatisticalATE) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index f032529a..636e978c 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -256,7 +256,7 @@ end ##################################################################### -function (estimator::Union{NAIVE, OSE, TMLEE})(causalΨ::CausalCMCompositeEstimands, dataset; +function (estimator::Union{NAIVE, OSE, TMLEE})(causalΨ::CausalCMCompositeEstimands, scm, dataset; identification_method=BackdoorAdjustment(), cache=Dict(), verbosity=1 From 1b8d4a9d82c89ad8a0faceff37f2f1cf8d3f3298 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 30 Oct 2023 20:41:26 +0000 Subject: [PATCH 094/151] update docs --- docs/Project.toml | 2 +- docs/src/resources.md | 2 +- docs/src/user_guide/estimands.md | 16 +++++++--------- docs/src/walk_through.md | 14 +++++++------- src/counterfactual_mean_based/estimands.jl | 3 --- 5 files changed, 16 insertions(+), 21 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index da78bb9f..e7c9a2e9 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -24,7 +24,7 @@ CSV = "0.10.11" CairoMakie = "0.10.4" DataFrames = "1.5" Distributions = "0.25" -Documenter = "0.27" +Documenter = "1" EvoTrees = "0.16" Literate = "2.13" MLJ = "0.19.1" diff --git a/docs/src/resources.md b/docs/src/resources.md index b18c30dd..ea3542cd 100644 --- a/docs/src/resources.md +++ b/docs/src/resources.md @@ -1,6 +1,6 @@ # Resources -Targeted Learning is a difficult topic, while it is not strictly necessary to understand the details to use this package it can certainly help. Here is an incomplete list of external ressources that I found useful in my personnal search for enlightenment. +Targeted Learning is a difficult topic, while it is not strictly necessary to understand the details to use this package it can certainly help. Here is an incomplete list of external resources that I found useful in my journey. ## Websites diff --git a/docs/src/user_guide/estimands.md b/docs/src/user_guide/estimands.md index e294ab38..28e0ec5e 100644 --- a/docs/src/user_guide/estimands.md +++ b/docs/src/user_guide/estimands.md @@ -4,15 +4,13 @@ CurrentModule = TMLE # Estimands -Most causal questions can be translated into a causal estimand. Usually either an interventional or counterfactual quantity. What would have been the outcome if I had set this variable to this value? When identified, this causal estimand translates to a statistical estimand which can be estimated from data. For us, an estimand will be a functional, that is a function that takes as input a probability distribution, and outputs a real number or vector of real numbers. - -Mathematically speaking, denoting the estimand by ``\Psi``, the set of all probability distributions by ``\mathcal{M}``: +Most causal questions can be translated into a causal estimand. Usually either an interventional or counterfactual quantity. What would have been the outcome if I had set this variable to this value? When identified, this causal estimand translates to a statistical estimand which can be estimated from data. From a mathematical standpoint, an estimand (``\Psi``) is a functional, that is a function that takes as input a probability distribution (from a model ``\mathcal{M}``), and outputs a real number or vector of real numbers. ```math \Psi: \mathcal{M} \rightarrow \mathbb{R}^p ``` -At the moment, most of the work in this package has been focused on estimands that are composite functions of the interventional conditional mean which is easily identified via backdoor adjustment and for which the efficient influence function is well known. +At the moment, most of the work in this package has been focused on estimands that are composite functions of the counterfactual mean which is easily identified via backdoor adjustment and for which the gradient is well known. In what follows, ``P`` is a probability distribution generating an outcome ``Y``, a random vector of "treatment" variables ``\textbf{T}`` and a random vector of "confounding" variables ``\textbf{W}``. For the examples, we will assume two treatment variables ``T₁`` and ``T₂`` taking either values 0 or 1. The ``SCM`` is given by: @@ -24,11 +22,11 @@ scm = StaticSCM( ) ``` -## The Interventional Counterfactual Mean (CM) +## The Counterfactual Mean (CM) - Causal Question: -What would be the mean of ``Y`` in the population if we intervened on ``\textbf{T}`` and set it to ``\textbf{t}``? +What would have been the mean of ``Y`` had we set ``\textbf{T}=\textbf{t}``? - Causal Estimand: @@ -39,7 +37,7 @@ CM_{\textbf{t}}(P) = \mathbb{E}[Y|do(\textbf{T}=\textbf{t})] - Statistical Estimand (via backdoor adjustment): ```math -CM_{\textbf{t}}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|do(\textbf{T}=\textbf{t}), \textbf{W}]] +CM_{\textbf{t}}(P) = \mathbb{E}_{\textbf{W}}[\mathbb{E}[Y|\textbf{T}=\textbf{t}, \textbf{W}]] ``` - TMLE.jl Example @@ -70,7 +68,7 @@ statisticalΨ = CM( - Causal Question: -What is the average difference in treatment effect on ``Y`` when the two treatment levels are set to ``\textbf{t}_1`` and ``\textbf{t}_2`` respectively? +What would be the average difference on ``Y`` if we switch the treatment levels from ``\textbf{t}_1`` to ``\textbf{t}_2``? - Causal Estimand: @@ -124,7 +122,7 @@ statisticalΨ = ATE( - Causal Question: -Interactions can be defined up to any order but we restrict the interpretation to two variables. Is the Total Average Treatment Effect of ``T₁`` and ``T₂`` different from the sum of their respective marginal Average Treatment Effects? Is there a synergistic effect between ``T₁`` and ``T₂`` on ``Y``. +Interactions can be defined up to any order but we restrict the interpretation to two variables. Is the Total Average Treatment Effect of ``T₁`` and ``T₂`` different from the sum of their respective marginal Average Treatment Effects? Is there a synergistic additive effect between ``T₁`` and ``T₂`` on ``Y``. For a general higher-order definition, please refer to [Higher-order interactions in statistical physics and machine learning: A model-independent solution to the inverse problem at equilibrium](https://arxiv.org/abs/2006.06010). diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index fe03cdd7..a478f031 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -58,11 +58,11 @@ Even though the role of a variable (treatment, outcome, confounder, ...) is rela - 1 Outcome variable (``Y``). - 2 Treatment variables ``(T₁, T₂)`` with confounders ``(W₁₁, W₁₂)`` and ``(W₂₁, W₂₂)`` respectively. -- 1 Extra Covariate variable (``C``). +- 1 Outcome extra covariate variable (``C``). ## The Structural Causal Model -The modeling stage starts from the definition of a Structural Causal Model (`SCM`). This is simply a list of Structural Equations (`SE`) describing the relationships between the random variables associated with our problem. See [Structural Causal Models](@ref) for an in-depth explanation. For our purposes, we will simply define it as follows: +The modeling stage starts from the definition of a Structural Causal Model (`SCM`). This is simply a list of relationships between the random variables in our dataset. See [Structural Causal Models](@ref) for an in-depth explanation. For our purposes, because we know the data generating process, we can define it as follows: ```@example walk-through scm = SCM( @@ -74,7 +74,7 @@ scm = SCM( ## The Causal Estimands -From the previous causal model we can ask multiple causal questions which are each represented by distinct causal estimands. The set of available estimands types can be listed as follow: +From the previous causal model we can ask multiple causal questions, all represented by distinct causal estimands. The set of available estimands types can be listed as follow: ```@example walk-through AVAILABLE_ESTIMANDS @@ -131,7 +131,7 @@ Alternatively, you can also directly define the statistical parameters (see [Est ## Estimation -Then each parameter can be estimated by building an estimator (which is simply a function) and calling it on data. For illustration, we will keep the models used simple linear models. We define a TMLE: +Then each parameter can be estimated by building an estimator (which is simply a function) and evaluating it on data. For illustration, we will keep the models simple. We define a Targeted Maximum Likelihood Estimator: ```@example walk-through models = ( @@ -142,14 +142,14 @@ models = ( tmle = TMLEE(models) ``` -And estimate the `cm` causal parameter. Because we haven't identified the parameter, we need to provide the `scm` as well: +Because we haven't identified the `cm` causal estimand yet, we need to provide the `scm` as well to the estimator: ```@example walk-through result, cache = tmle(cm, scm, dataset); result ``` -We can estimate statistical parameters directly, let's use the One-Step estimator: +Statistical Estimands can be estimated without a ``SCM``, let's use the One-Step estimator: ```@example walk-through ose = OSE(models) @@ -159,7 +159,7 @@ result ## Hypothesis Testing -Because the TMLE and OSE are asymptotically linear estimators, they asymptotically follow a Normal distribution. This means one can perform standard T/Z tests of null hypothesis. TMLE.jl extends the method provided by the [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) package that can be used as follows. +Both TMLE and OSE asymptotically follow a Normal distribution. It means we can perform standard T/Z tests of null hypothesis. TMLE.jl extends the method provided by the [HypothesisTests.jl](https://juliastats.org/HypothesisTests.jl/stable/) package that can be used as follows. ```@example walk-through OneSampleTTest(result) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 761948a4..b9eaa5ef 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -14,9 +14,6 @@ end CMRelevantFactors(outcome_mean, propensity_score::ConditionalDistribution) = CMRelevantFactors(outcome_mean, (propensity_score,)) -CMRelevantFactors(outcome_mean, propensity_score) = - CMRelevantFactors(outcome_mean, propensity_score) - CMRelevantFactors(;outcome_mean, propensity_score) = CMRelevantFactors(outcome_mean, propensity_score) From 2e1a798589b8e7bdcbe8d45d1b99f9ecbc14dcab Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 30 Oct 2023 21:37:39 +0000 Subject: [PATCH 095/151] try see if docs are built --- docs/make.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 5500fc44..648af96c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -30,14 +30,15 @@ makedocs(; "Walk Through" => "walk_through.md", "User Guide" => [joinpath("user_guide", f) for f in ("scm.md", "estimands.md", "estimation.md", "misc.md")], - "Examples" => [ - # joinpath("examples", "introduction_to_targeted_learning.md"), - # joinpath("examples", "super_learning.md"), - # joinpath("examples", "double_robustness.md") - ], + # "Examples" => [ + # joinpath("examples", "introduction_to_targeted_learning.md"), + # joinpath("examples", "super_learning.md"), + # joinpath("examples", "double_robustness.md") + # ], "Resources" => "resources.md", "API Reference" => "api.md" ], + pagesonly=true ) @info "Deploying docs..." From d2e775dce6d2524ea10841be463d088dfdd0fc4a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 31 Oct 2023 11:19:59 +0000 Subject: [PATCH 096/151] update docs --- docs/make.jl | 23 ++--- docs/src/user_guide/estimands.md | 3 +- examples/double_robustness.jl | 13 +-- examples/super_learning.jl | 151 ++++++++++++++++--------------- 4 files changed, 101 insertions(+), 89 deletions(-) diff --git a/docs/make.jl b/docs/make.jl index 648af96c..1b5acea6 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -8,11 +8,11 @@ DocMeta.setdocmeta!(TMLE, :DocTestSetup, :(using TMLE); recursive=true) @info "Building Literate pages..." examples_dir = joinpath(@__DIR__, "../examples") build_examples_dir = joinpath(@__DIR__, "src", "examples/") -# for file in readdir(examples_dir) -# if endswith(file, "jl") -# Literate.markdown(joinpath(examples_dir, file), build_examples_dir;documenter=true) -# end -# end +for file in readdir(examples_dir) + if endswith(file, "jl") + Literate.markdown(joinpath(examples_dir, file), build_examples_dir;documenter=true) + end +end @info "Running makedocs..." makedocs(; @@ -30,15 +30,16 @@ makedocs(; "Walk Through" => "walk_through.md", "User Guide" => [joinpath("user_guide", f) for f in ("scm.md", "estimands.md", "estimation.md", "misc.md")], - # "Examples" => [ - # joinpath("examples", "introduction_to_targeted_learning.md"), - # joinpath("examples", "super_learning.md"), - # joinpath("examples", "double_robustness.md") - # ], + "Examples" => [ + joinpath("examples", "super_learning.md"), + joinpath("examples", "double_robustness.md") + ], "Resources" => "resources.md", "API Reference" => "api.md" ], - pagesonly=true + pagesonly=true, + clean = true, + checkdocs=:exports ) @info "Deploying docs..." diff --git a/docs/src/user_guide/estimands.md b/docs/src/user_guide/estimands.md index 28e0ec5e..219b6f7c 100644 --- a/docs/src/user_guide/estimands.md +++ b/docs/src/user_guide/estimands.md @@ -15,6 +15,7 @@ At the moment, most of the work in this package has been focused on estimands th In what follows, ``P`` is a probability distribution generating an outcome ``Y``, a random vector of "treatment" variables ``\textbf{T}`` and a random vector of "confounding" variables ``\textbf{W}``. For the examples, we will assume two treatment variables ``T₁`` and ``T₂`` taking either values 0 or 1. The ``SCM`` is given by: ```@example estimands +using TMLE scm = StaticSCM( [:Y], [:T₁, :T₂], @@ -51,7 +52,7 @@ causalΨ = CM(outcome=:Y, treatment_values=(T₁=1, T₂=0)) A corresponding statistical estimand can be identified via backdoor adjustment using the `scm`: ```@example estimands -statisticalΨ = identify(Ψ, scm) +statisticalΨ = identify(causalΨ, scm) ``` or defined directly: diff --git a/examples/double_robustness.jl b/examples/double_robustness.jl index c0efdf7c..3a1f5952 100644 --- a/examples/double_robustness.jl +++ b/examples/double_robustness.jl @@ -154,13 +154,14 @@ that we now have full coverage of the ground truth. function tmle_inference(data) Ψ = ATE( outcome=:Y, - treatment=(Tcat=(case=1.0, control=0.0),), - confounders=[:W] + treatment_values=(Tcat=(case=1.0, control=0.0),), + treatment_confounders=(Tcat=[:W],) ) - result, _ = tmle!(Ψ, data; verbosity=0) - tmleresult = tmle(result) - lb, ub = confint(OneSampleTTest(tmleresult)) - return (TMLE.estimate(tmleresult), lb, ub) + models = (Y=with_encoder(LinearRegressor()), Tcat=LinearBinaryClassifier()) + tmle = TMLEE(models) + result, _ = tmle(Ψ, data; verbosity=0) + lb, ub = confint(OneSampleTTest(result)) + return (TMLE.estimate(result), lb, ub) end make_animation(tmle_inference) diff --git a/examples/super_learning.jl b/examples/super_learning.jl index a585e556..bc124a32 100644 --- a/examples/super_learning.jl +++ b/examples/super_learning.jl @@ -6,7 +6,7 @@ Super Learning, also known as Stacking, is an ensemble technique that was first introduced by Wolpert in 1992. Instead of selecting a model based on cross-validation performance, models are combined by a meta-learner to minimize the cross-validation error. It has also been shown by van der Laan et al. that the resulting -Super Learner will perform at least as well as its best performing submodel. +Super Learner will perform at least as well as its best performing submodel (at least asymptotically). Why is it important for Targeted Learning? @@ -16,38 +16,19 @@ By only using unrealistic models like linear models, we have little chance of sa Super Learning is a data driven way to leverage a diverse set of models and build the best performing estimator for both ``Q_0`` or ``G_0``. -In the following, we investigate the benefits of Super Learning for the estimation of the Average Treatment Effect. - ## The dataset -For this example we will use the following perinatal dataset. The (haz01, parity01) are converted to -categorical values. -=# -using CSV -using DataFrames -using TMLE -using MLJ - -dataset = CSV.read( - joinpath(pkgdir(TMLE), "test", "data", "perinatal.csv"), - DataFrame, - select=[:haz01, :parity01, :apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn], - types=Float64 -) -dataset.haz01 = categorical(Int.(dataset.haz01)) -dataset.parity01 = categorical(Int.(dataset.parity01)) -nothing # hide - -#= +Let's consider the case where Y is categorical. In TMLE.jl, this could be useful to learn: -We will also assume the following causal model: +- The propensity score +- The outcome model when the outcome is binary +We will use the following moons dataset: =# +using MLJ -scm = SCM( - SE(:haz01, [:parity01, :apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn]), - SE(:parity01, [:apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn]) -) +X, y = MLJ.make_moons(1000) +nothing # hide #= @@ -55,83 +36,111 @@ scm = SCM( In MLJ, a Super Learner can be defined using the [Stack](https://alan-turing-institute.github.io/MLJ.jl/stable/model_stacking/) function. The three most important type of arguments for a Stack are: -- `metalearner`: The metalearner to be used to combine the weak learner to be defined. -- `resampling`: The cross-validation scheme, by default, a 6-fold cross-validation. +- `metalearner`: The metalearner to be used to combine the weak learner to be defined. Typically a generalized linear model. +- `resampling`: The cross-validation scheme, by default, a 6-fold cross-validation. Since we are working with categorical +data it is a good idea to make sure the splits are balanced. We will thus use a `StratifiedCV` resampling strategy. - `models...`: A series of named MLJ models. -One important point is that MLJ does not provide any model by itself, those have to be loaded from +One important point is that MLJ does not provide any model by itself, juat the API, models have to be loaded from external compatible libraries. You can search for available models that match your data. -In our case, for both ``G_0`` and ``Q_0`` we need classification models and we can see there are quire a few of them: =# -G_available_models = models(matching(dataset[!, parents(scm.parity01)], dataset.parity01)) +models(matching(X, y)) #= -!!! note "Stack limitations" - For now, there are a few limitations as to which models you can actually use within the Stack. - The most important is that if the output is categorical, each model must be `<: Probabilistic`, - which means that SVMs cannot be used as a weak learners for classification. +!!! note "Stack limitation" + The Stack cannot contain `<:Deterministic` models for classification. -Let's load a few model providing libraries and define our library for ``G_0``. +Let's load a few packages providing models and build our first Stack: =# -using EvoTrees +using MLJXGBoostInterface using MLJLinearModels -using MLJModels using NearestNeighborModels -function superlearner_models() - lambdas = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 0, 1., 10., 100.] - logistic_models = [LogisticClassifier(lambda=l) for l in lambdas] - logistic_models = NamedTuple{Tuple(Symbol("lr_$i") for i in eachindex(lambdas))}(logistic_models) - evo_trees = [EvoTreeClassifier(lambda=l) for l in lambdas] - evo_trees = NamedTuple{Tuple(Symbol("tree_$i") for i in eachindex(lambdas))}(evo_trees) - Ks = [5, 10, 50, 100] - knns = [KNNClassifier(K=k) for k in Ks] - knns = NamedTuple{Tuple(Symbol("knn_$i") for i in eachindex(Ks))}(knns) - return merge(logistic_models, evo_trees, knns) -end - -superlearner = Stack(; - metalearner = LogisticClassifier(lambda=0), - resampling = StratifiedCV(nfolds=3), - measure = log_loss, - superlearner_models()... -) +resampling = StratifiedCV() +metalearner = LogisticClassifier() -nothing # hide +stack = Stack( + metalearner = metalearner, + resampling = resampling, + lr = LogisticClassifier(), + knn = KNNClassifier(K=3) +) #= +This Stack only contains 2 different models: a logistic classifier and a KNN classifier. +A Stack is just like any MLJ model, it can be wrapped in a `machine` and fitted: +=# -and assign those models to the `SCM`: +mach = machine(stack, X, y) +fit!(mach, verbosity=0) +#= +Or evaluated. Because the Stack contains a cross-validation procedure, this will result in two nested levels of resampling. =# -setmodel!(scm.haz01, with_encoder(superlearner)) -setmodel!(scm.parity01, superlearner) +evaluate!(mach, measure=log_loss, resampling=resampling) #= -## Targeted estimation +## A more advanced Stack + +What are good Stack members? Virtually anything, provided they are MLJ models. Here are a few examples: + +- You can use the stack to "select" model hyper-parameters. e.g. `KNNClassifier(K=3)` or `KNNClassifier(K=2)`? +- You can also use [self-tuning models](https://alan-turing-institute.github.io/MLJ.jl/stable/tuning_models/). Note that because +these models resort to cross-validation, fitting the stack will result in two nested +levels of sample-splitting. -Let us move to the targeted estimation step itself. We define the target estimand (the ATE): +The following self-tuned XGBoost will vary some hyperparameters in an internal sample-splitting procedure in order to optimize the Log-Loss. +It will then be combined with the rest of the models in the Stack's own sample-splitting procedure. Finally, evaluation is performed in an outer sample-split. =# -Ψ = ATE( - scm, - outcome=:haz01, - treatment=(parity01=(case=true, control=false),), + +xgboost = XGBoostClassifier(tree_method="hist") +self_tuning_xgboost = TunedModel( + model = xgboost, + resampling = resampling, + tuning = Grid(goal=20), + range = [ + range(xgboost, :max_depth, lower=3, upper=7), + range(xgboost, :lambda, lower=1e-5, upper=10, scale=:log) + ], + measure = log_loss, ) +stack = Stack( + metalearner = metalearner, + resampling = resampling, + self_tuning_xgboost = self_tuning_xgboost, + lr = LogisticClassifier(), + knn_2 = KNNClassifier(K=2), + knn_3 = KNNClassifier(K=3) +) -nothing # hide +mach = machine(stack, X, y) +evaluate!(mach, measure=log_loss, resampling=resampling) #= -Finally run the TMLE procedure: + +## Diagnostic + +Optionally, one can also investigate how sucessful the weak learners were in the Stack's internal +cross-validation. This is done by specifying the `measures` keyword argument. + +Here we look at both the Log-Loss and the AUC. + =# -tmle_result, _ = tmle!(Ψ, dataset) -tmle_result +stack.measures = [log_loss, auc] +fit!(mach, verbosity=0) +report(mach).cv_report +#= + +One can look at the fitted parameters for the metalearner as well: +=# +fitted_params(mach).metalearner From a761ea7c033dce75804d23ecb5db6057932b84f4 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 31 Oct 2023 11:23:25 +0000 Subject: [PATCH 097/151] fix project --- docs/Project.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index e7c9a2e9..1ffd37c6 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -12,14 +12,17 @@ MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c" MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692" MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +TMLE = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] +MLJXGBoostInterface = "0.3.8" CSV = "0.10.11" CairoMakie = "0.10.4" DataFrames = "1.5" From 2061cf63a47529ed17fa84a21b78c8234364458b Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 31 Oct 2023 12:15:27 +0000 Subject: [PATCH 098/151] add show method for scm --- examples/super_learning.jl | 6 ++-- src/counterfactual_mean_based/adjustment.jl | 4 +-- src/scm.jl | 37 ++++++++++++++++----- 3 files changed, 35 insertions(+), 12 deletions(-) diff --git a/examples/super_learning.jl b/examples/super_learning.jl index bc124a32..7009bb20 100644 --- a/examples/super_learning.jl +++ b/examples/super_learning.jl @@ -108,6 +108,7 @@ self_tuning_xgboost = TunedModel( range(xgboost, :lambda, lower=1e-5, upper=10, scale=:log) ], measure = log_loss, + cache=false ) stack = Stack( @@ -116,10 +117,11 @@ stack = Stack( self_tuning_xgboost = self_tuning_xgboost, lr = LogisticClassifier(), knn_2 = KNNClassifier(K=2), - knn_3 = KNNClassifier(K=3) + knn_3 = KNNClassifier(K=3), + cache = false ) -mach = machine(stack, X, y) +mach = machine(stack, X, y, cache=false) evaluate!(mach, measure=log_loss, resampling=resampling) #= diff --git a/src/counterfactual_mean_based/adjustment.jl b/src/counterfactual_mean_based/adjustment.jl index 9fbca3b4..b972cca1 100644 --- a/src/counterfactual_mean_based/adjustment.jl +++ b/src/counterfactual_mean_based/adjustment.jl @@ -18,11 +18,11 @@ function statistical_type_from_causal_type(T) return eval(Symbol(new_typestring)) end -identify(causal_estimand::T, scm::MetaGraph; method=BackdoorAdjustment()::BackdoorAdjustment) where T<:CausalCMCompositeEstimands = +identify(causal_estimand::T, scm::SCM; method=BackdoorAdjustment()::BackdoorAdjustment) where T<:CausalCMCompositeEstimands = identify(method, causal_estimand, scm) -function identify(method::BackdoorAdjustment, causal_estimand::T, scm::MetaGraph) where T<:CausalCMCompositeEstimands +function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) where T<:CausalCMCompositeEstimands # Treatment confounders treatment_names = keys(causal_estimand.treatment_values) treatment_codes = [code_for(scm, treatment) for treatment ∈ treatment_names] diff --git a/src/scm.jl b/src/scm.jl index 4487ea53..9820529d 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -3,21 +3,42 @@ ##################################################################### """ -A SCM is simply a MetaGraph over a Directed Acyclic Graph with additional methods. +A SCM is simply a wrapper around a MetaGraph over a Directed Acyclic Graph. """ -const SCM = MetaGraph +struct SCM + graph::MetaGraph +end function SCM(equations...) - scm = MetaGraph( + graph = MetaGraph( SimpleDiGraph(); label_type=Symbol, vertex_data_type=Nothing, edge_data_type=Nothing ) + scm = SCM(graph) add_equations!(scm, equations...) return scm end +function string_repr(scm::SCM) + string_rep = "SCM\n---" + for (vertex_id, vertex_label) in scm.graph.vertex_labels + vertex_parents = parents(scm, vertex_label) + if length(vertex_parents) > 0 + digits = split(string(vertex_id), "") + subscript = join(Meta.parse("'\\U0208$digit'") for digit in digits) + eq_string = string("\n", vertex_label, " = f", subscript, "(", join(vertex_parents, ", "), ")") + string_rep = string(string_rep, eq_string) + end + end + return string_rep +end + +Base.show(io::IO, ::MIME"text/plain", scm::SCM) = + println(io, string_repr(scm)) + + function add_equations!(scm::SCM, equations...) for outcome_parents_pair in equations add_equation!(scm, outcome_parents_pair) @@ -27,17 +48,17 @@ end function add_equation!(scm::SCM, outcome_parents_pair) outcome, parents = outcome_parents_pair outcome_symbol = Symbol(outcome) - add_vertex!(scm, outcome_symbol) + add_vertex!(scm.graph, outcome_symbol) for parent in parents parent_symbol = Symbol(parent) - add_vertex!(scm, parent_symbol) - add_edge!(scm, parent_symbol, outcome_symbol) + add_vertex!(scm.graph, parent_symbol) + add_edge!(scm.graph, parent_symbol, outcome_symbol) end end function parents(scm, label) - code, _ = scm.vertex_properties[label] - return [scm.vertex_labels[parent_code] for parent_code in scm.graph.badjlist[code]] + code, _ = scm.graph.vertex_properties[label] + return [scm.graph.vertex_labels[parent_code] for parent_code in scm.graph.graph.badjlist[code]] end """ From 29cea3eeb87614c2c9a026ae710a41b675261e74 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 31 Oct 2023 13:15:31 +0000 Subject: [PATCH 099/151] fix code --- src/counterfactual_mean_based/adjustment.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/counterfactual_mean_based/adjustment.jl b/src/counterfactual_mean_based/adjustment.jl index b972cca1..14e11400 100644 --- a/src/counterfactual_mean_based/adjustment.jl +++ b/src/counterfactual_mean_based/adjustment.jl @@ -25,8 +25,8 @@ identify(causal_estimand::T, scm::SCM; method=BackdoorAdjustment()::BackdoorAdju function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) where T<:CausalCMCompositeEstimands # Treatment confounders treatment_names = keys(causal_estimand.treatment_values) - treatment_codes = [code_for(scm, treatment) for treatment ∈ treatment_names] - confounders_codes = scm.graph.badjlist[treatment_codes] + treatment_codes = [code_for(scm.graph, treatment) for treatment ∈ treatment_names] + confounders_codes = scm.graph.graph.badjlist[treatment_codes] treatment_confounders = NamedTuple{treatment_names}( [[scm.vertex_labels[w] for w in confounders_codes[i]] for i in eachindex(confounders_codes)] From 6e50645f8fa649a3aeff38ad826212cc563ac64a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 31 Oct 2023 14:01:03 +0000 Subject: [PATCH 100/151] fix code --- src/counterfactual_mean_based/adjustment.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/counterfactual_mean_based/adjustment.jl b/src/counterfactual_mean_based/adjustment.jl index 14e11400..88b0bec8 100644 --- a/src/counterfactual_mean_based/adjustment.jl +++ b/src/counterfactual_mean_based/adjustment.jl @@ -28,7 +28,7 @@ function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) wher treatment_codes = [code_for(scm.graph, treatment) for treatment ∈ treatment_names] confounders_codes = scm.graph.graph.badjlist[treatment_codes] treatment_confounders = NamedTuple{treatment_names}( - [[scm.vertex_labels[w] for w in confounders_codes[i]] + [[scm.graph.vertex_labels[w] for w in confounders_codes[i]] for i in eachindex(confounders_codes)] ) From 1677c9125e7fef2c5e967cc011f4f30fe8bdbdb7 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 1 Nov 2023 11:31:15 +0000 Subject: [PATCH 101/151] add configurations back in --- .gitignore | 3 +- Project.toml | 6 +- docs/Manifest.toml | 2055 ----------------- docs/Project.toml | 6 +- src/TMLE.jl | 5 +- .../configurations.jl | 37 + src/counterfactual_mean_based/estimands.jl | 4 +- .../configurations.jl | 53 + test/runtests.jl | 1 + 9 files changed, 103 insertions(+), 2067 deletions(-) delete mode 100644 docs/Manifest.toml create mode 100644 src/counterfactual_mean_based/configurations.jl create mode 100644 test/counterfactual_mean_based/configurations.jl diff --git a/.gitignore b/.gitignore index 59c39846..d3330493 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ build/ sandbox.jl /Manifest.toml -/test/Manifest.toml \ No newline at end of file +/test/Manifest.toml +/docs/Manifest.toml \ No newline at end of file diff --git a/Project.toml b/Project.toml index 17953fd0..5be2577f 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.12.0" [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" +Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -30,11 +31,13 @@ AbstractDifferentiation = "0.4, 0.5" CategoricalArrays = "0.10" Distributions = "0.25" GLM = "1.8.2" +Graphs = "1.8" HypothesisTests = "0.10, 0.11" LogExpFunctions = "0.3" MLJBase = "0.19, 0.20, 0.21" MLJGLMInterface = "0.3.4" MLJModels = "0.15, 0.16" +MetaGraphsNext = "0.6" Missings = "1.0" PrecompileTools = "1.1.1" PrettyTables = "2.2" @@ -42,6 +45,5 @@ TableOperations = "1.2" Tables = "1.6" YAML = "0.4" Zygote = "0.6" -Graphs = "1.8" -MetaGraphsNext = "0.6" +Configurations = "0.17.6" julia = "1.6, 1.7, 1" diff --git a/docs/Manifest.toml b/docs/Manifest.toml deleted file mode 100644 index 39087c80..00000000 --- a/docs/Manifest.toml +++ /dev/null @@ -1,2055 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -julia_version = "1.9.2" -manifest_format = "2.0" -project_hash = "0251892739390e1f6b399b7b0a8ba56b014a5eee" - -[[deps.ANSIColoredPrinters]] -git-tree-sha1 = "574baf8110975760d391c710b6341da1afa48d8c" -uuid = "a4c015fc-c6ff-483c-b24f-f7ea428134e9" -version = "0.0.1" - -[[deps.ARFFFiles]] -deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] -git-tree-sha1 = "e8c8e0a2be6eb4f56b1672e46004463033daa409" -uuid = "da404889-ca92-49ff-9e8b-0aa6b4d38dc8" -version = "1.4.1" - -[[deps.AbstractFFTs]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "d92ad398961a3ed262d8bf04a1a2b8340f915fef" -uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c" -version = "1.5.0" -weakdeps = ["ChainRulesCore", "Test"] - - [deps.AbstractFFTs.extensions] - AbstractFFTsChainRulesCoreExt = "ChainRulesCore" - AbstractFFTsTestExt = "Test" - -[[deps.AbstractLattices]] -git-tree-sha1 = "f35684b7349da49fcc8a9e520e30e45dbb077166" -uuid = "398f06c4-4d28-53ec-89ca-5b2656b7603d" -version = "0.2.1" - -[[deps.AbstractTrees]] -git-tree-sha1 = "faa260e4cb5aba097a73fab382dd4b5819d8ec8c" -uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" -version = "0.4.4" - -[[deps.Adapt]] -deps = ["LinearAlgebra", "Requires"] -git-tree-sha1 = "76289dc51920fdc6e0013c872ba9551d54961c24" -uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.6.2" -weakdeps = ["StaticArrays"] - - [deps.Adapt.extensions] - AdaptStaticArraysExt = "StaticArrays" - -[[deps.Animations]] -deps = ["Colors"] -git-tree-sha1 = "e81c509d2c8e49592413bfb0bb3b08150056c79d" -uuid = "27a7e980-b3e6-11e9-2bcd-0b925532e340" -version = "0.4.1" - -[[deps.ArgTools]] -uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" -version = "1.1.1" - -[[deps.ArrayInterface]] -deps = ["Adapt", "LinearAlgebra", "Requires", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "f83ec24f76d4c8f525099b2ac475fc098138ec31" -uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.4.11" - - [deps.ArrayInterface.extensions] - ArrayInterfaceBandedMatricesExt = "BandedMatrices" - ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" - ArrayInterfaceCUDAExt = "CUDA" - ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" - ArrayInterfaceStaticArraysCoreExt = "StaticArraysCore" - ArrayInterfaceTrackerExt = "Tracker" - - [deps.ArrayInterface.weakdeps] - BandedMatrices = "aae01518-5342-5314-be14-df237901396f" - BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" - CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" - GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" - StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" - Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" - -[[deps.Artifacts]] -uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" - -[[deps.Atomix]] -deps = ["UnsafeAtomics"] -git-tree-sha1 = "c06a868224ecba914baa6942988e2f2aade419be" -uuid = "a9b6321e-bd34-4604-b9c9-b65b8de01458" -version = "0.1.0" - -[[deps.Automa]] -deps = ["TranscodingStreams"] -git-tree-sha1 = "ef9997b3d5547c48b41c7bd8899e812a917b409d" -uuid = "67c07d97-cdcb-5c2c-af73-a7f9c32a568b" -version = "0.8.4" - -[[deps.AxisAlgorithms]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "WoodburyMatrices"] -git-tree-sha1 = "66771c8d21c8ff5e3a93379480a2307ac36863f7" -uuid = "13072b0f-2c55-5437-9ae7-d433b7a33950" -version = "1.0.1" - -[[deps.AxisArrays]] -deps = ["Dates", "IntervalSets", "IterTools", "RangeArrays"] -git-tree-sha1 = "16351be62963a67ac4083f748fdb3cca58bfd52f" -uuid = "39de3d68-74b9-583c-8d2d-e117c070f3a9" -version = "0.4.7" - -[[deps.BFloat16s]] -deps = ["LinearAlgebra", "Printf", "Random", "Test"] -git-tree-sha1 = "dbf84058d0a8cbbadee18d25cf606934b22d7c66" -uuid = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" -version = "0.4.2" - -[[deps.BSON]] -git-tree-sha1 = "2208958832d6e1b59e49f53697483a84ca8d664e" -uuid = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0" -version = "0.3.7" - -[[deps.Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[deps.BitFlags]] -git-tree-sha1 = "43b1a4a8f797c1cddadf60499a8a077d4af2cd2d" -uuid = "d1d4a3ce-64b1-5f1a-9ba4-7e7e69966f35" -version = "0.1.7" - -[[deps.Bzip2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "19a35467a82e236ff51bc17a3a44b69ef35185a2" -uuid = "6e34b625-4abd-537c-b88f-471c36dfa7a0" -version = "1.0.8+0" - -[[deps.CEnum]] -git-tree-sha1 = "eb4cb44a499229b3b8426dcfb5dd85333951ff90" -uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82" -version = "0.4.2" - -[[deps.CRC32c]] -uuid = "8bf52ea8-c179-5cab-976a-9e18b702a9bc" - -[[deps.CRlibm]] -deps = ["CRlibm_jll"] -git-tree-sha1 = "32abd86e3c2025db5172aa182b982debed519834" -uuid = "96374032-68de-5a5b-8d9e-752f78720389" -version = "1.0.1" - -[[deps.CRlibm_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "e329286945d0cfc04456972ea732551869af1cfc" -uuid = "4e9b3aee-d8a1-5a3d-ad8b-7d824db253f0" -version = "1.0.1+0" - -[[deps.CSV]] -deps = ["CodecZlib", "Dates", "FilePathsBase", "InlineStrings", "Mmap", "Parsers", "PooledArrays", "PrecompileTools", "SentinelArrays", "Tables", "Unicode", "WeakRefStrings", "WorkerUtilities"] -git-tree-sha1 = "44dbf560808d49041989b8a96cae4cffbeb7966a" -uuid = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" -version = "0.10.11" - -[[deps.CUDA]] -deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CUDA_Driver_jll", "CUDA_Runtime_Discovery", "CUDA_Runtime_jll", "ExprTools", "GPUArrays", "GPUCompiler", "KernelAbstractions", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Preferences", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "35160ef0f03b14768abfd68b830f8e3940e8e0dc" -uuid = "052768ef-5323-5732-b1bb-66c8b64840ba" -version = "4.4.0" - -[[deps.CUDA_Driver_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "498f45593f6ddc0adff64a9310bb6710e851781b" -uuid = "4ee394cb-3365-5eb0-8335-949819d2adfc" -version = "0.5.0+1" - -[[deps.CUDA_Runtime_Discovery]] -deps = ["Libdl"] -git-tree-sha1 = "bcc4a23cbbd99c8535a5318455dcf0f2546ec536" -uuid = "1af6417a-86b4-443c-805f-a4643ffb695f" -version = "0.2.2" - -[[deps.CUDA_Runtime_jll]] -deps = ["Artifacts", "CUDA_Driver_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "5248d9c45712e51e27ba9b30eebec65658c6ce29" -uuid = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" -version = "0.6.0+0" - -[[deps.Cairo]] -deps = ["Cairo_jll", "Colors", "Glib_jll", "Graphics", "Libdl", "Pango_jll"] -git-tree-sha1 = "d0b3f8b4ad16cb0a2988c6788646a5e6a17b6b1b" -uuid = "159f3aea-2a34-519c-b102-8c37f9878175" -version = "1.0.5" - -[[deps.CairoMakie]] -deps = ["Base64", "Cairo", "Colors", "FFTW", "FileIO", "FreeType", "GeometryBasics", "LinearAlgebra", "Makie", "PrecompileTools", "SHA"] -git-tree-sha1 = "e041782fed7614b1726fa250f2bf24fd5c789689" -uuid = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -version = "0.10.7" - -[[deps.Cairo_jll]] -deps = ["Artifacts", "Bzip2_jll", "CompilerSupportLibraries_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "JLLWrappers", "LZO_jll", "Libdl", "Pixman_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libXrender_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "4b859a208b2397a7a623a03449e4636bdb17bcf2" -uuid = "83423d85-b0ee-5818-9007-b63ccbeb887a" -version = "1.16.1+1" - -[[deps.Calculus]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "f641eb0a4f00c343bbc32346e1217b86f3ce9dad" -uuid = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" -version = "0.5.1" - -[[deps.CategoricalArrays]] -deps = ["DataAPI", "Future", "Missings", "Printf", "Requires", "Statistics", "Unicode"] -git-tree-sha1 = "1568b28f91293458345dabba6a5ea3f183250a61" -uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597" -version = "0.10.8" - - [deps.CategoricalArrays.extensions] - CategoricalArraysJSONExt = "JSON" - CategoricalArraysRecipesBaseExt = "RecipesBase" - CategoricalArraysSentinelArraysExt = "SentinelArrays" - CategoricalArraysStructTypesExt = "StructTypes" - - [deps.CategoricalArrays.weakdeps] - JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" - RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" - SentinelArrays = "91c51154-3ec4-41a3-a24f-3f23e20d615c" - StructTypes = "856f2bd8-1eba-4b0a-8007-ebc267875bd4" - -[[deps.CategoricalDistributions]] -deps = ["CategoricalArrays", "Distributions", "Missings", "OrderedCollections", "Random", "ScientificTypes"] -git-tree-sha1 = "da68989f027dcefa74d44a452c9e36af9730a70d" -uuid = "af321ab8-2d2e-40a6-b165-3d674595d28e" -version = "0.1.10" - - [deps.CategoricalDistributions.extensions] - UnivariateFiniteDisplayExt = "UnicodePlots" - - [deps.CategoricalDistributions.weakdeps] - UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" - -[[deps.ChainRulesCore]] -deps = ["Compat", "LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "e30f2f4e20f7f186dc36529910beaedc60cfa644" -uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "1.16.0" - -[[deps.CodecZlib]] -deps = ["TranscodingStreams", "Zlib_jll"] -git-tree-sha1 = "02aa26a4cf76381be7f66e020a3eddeb27b0a092" -uuid = "944b1d66-785c-5afd-91f1-9de20f533193" -version = "0.7.2" - -[[deps.ColorBrewer]] -deps = ["Colors", "JSON", "Test"] -git-tree-sha1 = "61c5334f33d91e570e1d0c3eb5465835242582c4" -uuid = "a2cac450-b92f-5266-8821-25eda20663c8" -version = "0.4.0" - -[[deps.ColorSchemes]] -deps = ["ColorTypes", "ColorVectorSpace", "Colors", "FixedPointNumbers", "PrecompileTools", "Random"] -git-tree-sha1 = "d9a8f86737b665e15a9641ecbac64deef9ce6724" -uuid = "35d6a980-a343-548e-a6ea-1d62b119f2f4" -version = "3.23.0" - -[[deps.ColorTypes]] -deps = ["FixedPointNumbers", "Random"] -git-tree-sha1 = "eb7f0f8307f71fac7c606984ea5fb2817275d6e4" -uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f" -version = "0.11.4" - -[[deps.ColorVectorSpace]] -deps = ["ColorTypes", "FixedPointNumbers", "LinearAlgebra", "SpecialFunctions", "Statistics", "TensorCore"] -git-tree-sha1 = "600cc5508d66b78aae350f7accdb58763ac18589" -uuid = "c3611d14-8923-5661-9e6a-0046d554d3a4" -version = "0.9.10" - -[[deps.Colors]] -deps = ["ColorTypes", "FixedPointNumbers", "Reexport"] -git-tree-sha1 = "fc08e5930ee9a4e03f84bfb5211cb54e7769758a" -uuid = "5ae59095-9a9b-59fe-a467-6f913c188581" -version = "0.12.10" - -[[deps.Combinatorics]] -git-tree-sha1 = "08c8b6831dc00bfea825826be0bc8336fc369860" -uuid = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" -version = "1.0.2" - -[[deps.CommonSubexpressions]] -deps = ["MacroTools", "Test"] -git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7" -uuid = "bbf7d656-a473-5ed7-a52c-81e309532950" -version = "0.3.0" - -[[deps.Compat]] -deps = ["UUIDs"] -git-tree-sha1 = "e460f044ca8b99be31d35fe54fc33a5c33dd8ed7" -uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" -version = "4.9.0" -weakdeps = ["Dates", "LinearAlgebra"] - - [deps.Compat.extensions] - CompatLinearAlgebraExt = "LinearAlgebra" - -[[deps.CompilerSupportLibraries_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" -version = "1.0.5+0" - -[[deps.ComputationalResources]] -git-tree-sha1 = "52cb3ec90e8a8bea0e62e275ba577ad0f74821f7" -uuid = "ed09eef8-17a6-5b46-8889-db040fac31e3" -version = "0.3.2" - -[[deps.ConcurrentUtilities]] -deps = ["Serialization", "Sockets"] -git-tree-sha1 = "5372dbbf8f0bdb8c700db5367132925c0771ef7e" -uuid = "f0e56b4a-5159-44fe-b623-3e5288b988bb" -version = "2.2.1" - -[[deps.ConstructionBase]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "fe2838a593b5f776e1597e086dcd47560d94e816" -uuid = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -version = "1.5.3" -weakdeps = ["IntervalSets", "StaticArrays"] - - [deps.ConstructionBase.extensions] - ConstructionBaseIntervalSetsExt = "IntervalSets" - ConstructionBaseStaticArraysExt = "StaticArrays" - -[[deps.Contour]] -git-tree-sha1 = "d05d9e7b7aedff4e5b51a029dced05cfb6125781" -uuid = "d38c429a-6771-53c6-b99e-75d170b6e991" -version = "0.6.2" - -[[deps.Crayons]] -git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" -uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" -version = "4.1.1" - -[[deps.DBInterface]] -git-tree-sha1 = "9b0dc525a052b9269ccc5f7f04d5b3639c65bca5" -uuid = "a10d1c49-ce27-4219-8d33-6db1a4562965" -version = "2.5.0" - -[[deps.DataAPI]] -git-tree-sha1 = "8da84edb865b0b5b0100c0666a9bc9a0b71c553c" -uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.15.0" - -[[deps.DataFrames]] -deps = ["Compat", "DataAPI", "DataStructures", "Future", "InlineStrings", "InvertedIndices", "IteratorInterfaceExtensions", "LinearAlgebra", "Markdown", "Missings", "PooledArrays", "PrecompileTools", "PrettyTables", "Printf", "REPL", "Random", "Reexport", "SentinelArrays", "SortingAlgorithms", "Statistics", "TableTraits", "Tables", "Unicode"] -git-tree-sha1 = "04c738083f29f86e62c8afc341f0967d8717bdb8" -uuid = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" -version = "1.6.1" - -[[deps.DataStructures]] -deps = ["Compat", "InteractiveUtils", "OrderedCollections"] -git-tree-sha1 = "3dbd312d370723b6bb43ba9d02fc36abade4518d" -uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" -version = "0.18.15" - -[[deps.DataValueInterfaces]] -git-tree-sha1 = "bfc1187b79289637fa0ef6d4436ebdfe6905cbd6" -uuid = "e2d170a0-9d28-54be-80f0-106bbe20a464" -version = "1.0.0" - -[[deps.Dates]] -deps = ["Printf"] -uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" - -[[deps.DelaunayTriangulation]] -deps = ["DataStructures", "EnumX", "ExactPredicates", "Random", "SimpleGraphs"] -git-tree-sha1 = "a1d8532de83f8ce964235eff1edeff9581144d02" -uuid = "927a84f5-c5f4-47a5-9785-b46e178433df" -version = "0.7.2" -weakdeps = ["MakieCore"] - - [deps.DelaunayTriangulation.extensions] - DelaunayTriangulationMakieCoreExt = "MakieCore" - -[[deps.DelimitedFiles]] -deps = ["Mmap"] -git-tree-sha1 = "9e2f36d3c96a820c678f2f1f1782582fcf685bae" -uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab" -version = "1.9.1" - -[[deps.DiffResults]] -deps = ["StaticArraysCore"] -git-tree-sha1 = "782dd5f4561f5d267313f23853baaaa4c52ea621" -uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" -version = "1.1.0" - -[[deps.DiffRules]] -deps = ["IrrationalConstants", "LogExpFunctions", "NaNMath", "Random", "SpecialFunctions"] -git-tree-sha1 = "23163d55f885173722d1e4cf0f6110cdbaf7e272" -uuid = "b552c78f-8df3-52c6-915a-8e097449b14b" -version = "1.15.1" - -[[deps.Distances]] -deps = ["LinearAlgebra", "Statistics", "StatsAPI"] -git-tree-sha1 = "b6def76ffad15143924a2199f72a5cd883a2e8a9" -uuid = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" -version = "0.10.9" -weakdeps = ["SparseArrays"] - - [deps.Distances.extensions] - DistancesSparseArraysExt = "SparseArrays" - -[[deps.Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - -[[deps.Distributions]] -deps = ["FillArrays", "LinearAlgebra", "PDMats", "Printf", "QuadGK", "Random", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "Test"] -git-tree-sha1 = "938fe2981db009f531b6332e31c58e9584a2f9bd" -uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" -version = "0.25.100" - - [deps.Distributions.extensions] - DistributionsChainRulesCoreExt = "ChainRulesCore" - DistributionsDensityInterfaceExt = "DensityInterface" - - [deps.Distributions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" - -[[deps.DocStringExtensions]] -deps = ["LibGit2"] -git-tree-sha1 = "2fb1e02f2b635d0845df5d7c167fec4dd739b00d" -uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.9.3" - -[[deps.Documenter]] -deps = ["ANSIColoredPrinters", "Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "39fd748a73dce4c05a9655475e437170d8fb1b67" -uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.27.25" - -[[deps.Downloads]] -deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] -uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" -version = "1.6.0" - -[[deps.DualNumbers]] -deps = ["Calculus", "NaNMath", "SpecialFunctions"] -git-tree-sha1 = "5837a837389fccf076445fce071c8ddaea35a566" -uuid = "fa6b7ba4-c1ee-5f82-b5fc-ecf0adba8f74" -version = "0.6.8" - -[[deps.EarCut_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "e3290f2d49e661fbd94046d7e3726ffcb2d41053" -uuid = "5ae413db-bbd1-5e63-b57d-d24a61df00f5" -version = "2.2.4+0" - -[[deps.EarlyStopping]] -deps = ["Dates", "Statistics"] -git-tree-sha1 = "98fdf08b707aaf69f524a6cd0a67858cefe0cfb6" -uuid = "792122b4-ca99-40de-a6bc-6742525f08b6" -version = "0.3.0" - -[[deps.EnumX]] -git-tree-sha1 = "bdb1942cd4c45e3c678fd11569d5cccd80976237" -uuid = "4e289a0a-7415-4d19-859d-a7e5c4648b56" -version = "1.0.4" - -[[deps.ErrorfreeArithmetic]] -git-tree-sha1 = "d6863c556f1142a061532e79f611aa46be201686" -uuid = "90fa49ef-747e-5e6f-a989-263ba693cf1a" -version = "0.5.2" - -[[deps.EvoTrees]] -deps = ["BSON", "CUDA", "CategoricalArrays", "Distributions", "MLJModelInterface", "NetworkLayout", "Random", "RecipesBase", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "5023442c1f797c0fd6677b1a1886ab44f43f3378" -uuid = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" -version = "0.16.0" - -[[deps.ExactPredicates]] -deps = ["IntervalArithmetic", "Random", "StaticArraysCore", "Test"] -git-tree-sha1 = "276e83bc8b21589b79303b9985c321024ffdf59c" -uuid = "429591f6-91af-11e9-00e2-59fbe8cec110" -version = "2.2.5" - -[[deps.ExceptionUnwrapping]] -deps = ["Test"] -git-tree-sha1 = "e90caa41f5a86296e014e148ee061bd6c3edec96" -uuid = "460bff9d-24e4-43bc-9d9f-a8973cb893f4" -version = "0.1.9" - -[[deps.Expat_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "4558ab818dcceaab612d1bb8c19cee87eda2b83c" -uuid = "2e619515-83b5-522b-bb60-26c02a35a201" -version = "2.5.0+0" - -[[deps.ExprTools]] -git-tree-sha1 = "27415f162e6028e81c72b82ef756bf321213b6ec" -uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04" -version = "0.1.10" - -[[deps.Extents]] -git-tree-sha1 = "5e1e4c53fa39afe63a7d356e30452249365fba99" -uuid = "411431e0-e8b7-467b-b5e0-f676ba4f2910" -version = "0.1.1" - -[[deps.FFMPEG]] -deps = ["FFMPEG_jll"] -git-tree-sha1 = "b57e3acbe22f8484b4b5ff66a7499717fe1a9cc8" -uuid = "c87230d0-a227-11e9-1b43-d7ebe4e7570a" -version = "0.4.1" - -[[deps.FFMPEG_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "JLLWrappers", "LAME_jll", "Libdl", "Ogg_jll", "OpenSSL_jll", "Opus_jll", "PCRE2_jll", "Zlib_jll", "libaom_jll", "libass_jll", "libfdk_aac_jll", "libvorbis_jll", "x264_jll", "x265_jll"] -git-tree-sha1 = "466d45dc38e15794ec7d5d63ec03d776a9aff36e" -uuid = "b22a6f82-2f65-5046-a5b2-351ab43fb4e5" -version = "4.4.4+1" - -[[deps.FFTW]] -deps = ["AbstractFFTs", "FFTW_jll", "LinearAlgebra", "MKL_jll", "Preferences", "Reexport"] -git-tree-sha1 = "b4fbdd20c889804969571cc589900803edda16b7" -uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" -version = "1.7.1" - -[[deps.FFTW_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "c6033cc3892d0ef5bb9cd29b7f2f0331ea5184ea" -uuid = "f5851436-0d7a-5f13-b9de-f02708fd171a" -version = "3.3.10+0" - -[[deps.FastRounding]] -deps = ["ErrorfreeArithmetic", "LinearAlgebra"] -git-tree-sha1 = "6344aa18f654196be82e62816935225b3b9abe44" -uuid = "fa42c844-2597-5d31-933b-ebd51ab2693f" -version = "0.3.1" - -[[deps.FileIO]] -deps = ["Pkg", "Requires", "UUIDs"] -git-tree-sha1 = "299dc33549f68299137e51e6d49a13b5b1da9673" -uuid = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" -version = "1.16.1" - -[[deps.FilePathsBase]] -deps = ["Compat", "Dates", "Mmap", "Printf", "Test", "UUIDs"] -git-tree-sha1 = "e27c4ebe80e8699540f2d6c805cc12203b614f12" -uuid = "48062228-2e41-5def-b9a4-89aafe57970f" -version = "0.9.20" - -[[deps.FileWatching]] -uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" - -[[deps.FillArrays]] -deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"] -git-tree-sha1 = "f372472e8672b1d993e93dada09e23139b509f9e" -uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "1.5.0" - -[[deps.FiniteDiff]] -deps = ["ArrayInterface", "LinearAlgebra", "Requires", "Setfield", "SparseArrays"] -git-tree-sha1 = "c6e4a1fbe73b31a3dea94b1da449503b8830c306" -uuid = "6a86dc24-6348-571c-b903-95158fe2bd41" -version = "2.21.1" - - [deps.FiniteDiff.extensions] - FiniteDiffBandedMatricesExt = "BandedMatrices" - FiniteDiffBlockBandedMatricesExt = "BlockBandedMatrices" - FiniteDiffStaticArraysExt = "StaticArrays" - - [deps.FiniteDiff.weakdeps] - BandedMatrices = "aae01518-5342-5314-be14-df237901396f" - BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" - StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" - -[[deps.FixedPointNumbers]] -deps = ["Statistics"] -git-tree-sha1 = "335bfdceacc84c5cdf16aadc768aa5ddfc5383cc" -uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93" -version = "0.8.4" - -[[deps.Fontconfig_jll]] -deps = ["Artifacts", "Bzip2_jll", "Expat_jll", "FreeType2_jll", "JLLWrappers", "Libdl", "Libuuid_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "21efd19106a55620a188615da6d3d06cd7f6ee03" -uuid = "a3f928ae-7b40-5064-980b-68af3947d34b" -version = "2.13.93+0" - -[[deps.Formatting]] -deps = ["Printf"] -git-tree-sha1 = "8339d61043228fdd3eb658d86c926cb282ae72a8" -uuid = "59287772-0a20-5a39-b81b-1366585eb4c0" -version = "0.4.2" - -[[deps.ForwardDiff]] -deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "LogExpFunctions", "NaNMath", "Preferences", "Printf", "Random", "SpecialFunctions"] -git-tree-sha1 = "cf0fe81336da9fb90944683b8c41984b08793dad" -uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.36" -weakdeps = ["StaticArrays"] - - [deps.ForwardDiff.extensions] - ForwardDiffStaticArraysExt = "StaticArrays" - -[[deps.FreeType]] -deps = ["CEnum", "FreeType2_jll"] -git-tree-sha1 = "cabd77ab6a6fdff49bfd24af2ebe76e6e018a2b4" -uuid = "b38be410-82b0-50bf-ab77-7b57e271db43" -version = "4.0.0" - -[[deps.FreeType2_jll]] -deps = ["Artifacts", "Bzip2_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "87eb71354d8ec1a96d4a7636bd57a7347dde3ef9" -uuid = "d7e528f0-a631-5988-bf34-fe36492bcfd7" -version = "2.10.4+0" - -[[deps.FreeTypeAbstraction]] -deps = ["ColorVectorSpace", "Colors", "FreeType", "GeometryBasics"] -git-tree-sha1 = "38a92e40157100e796690421e34a11c107205c86" -uuid = "663a7486-cb36-511b-a19d-713bb74d65c9" -version = "0.10.0" - -[[deps.FriBidi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "aa31987c2ba8704e23c6c8ba8a4f769d5d7e4f91" -uuid = "559328eb-81f9-559d-9380-de523a88c83c" -version = "1.0.10+0" - -[[deps.Future]] -deps = ["Random"] -uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820" - -[[deps.GLM]] -deps = ["Distributions", "LinearAlgebra", "Printf", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "StatsModels"] -git-tree-sha1 = "97829cfda0df99ddaeaafb5b370d6cab87b7013e" -uuid = "38e38edf-8417-5370-95a0-9cbb8c7f171a" -version = "1.8.3" - -[[deps.GPUArrays]] -deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"] -git-tree-sha1 = "2e57b4a4f9cc15e85a24d603256fe08e527f48d1" -uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -version = "8.8.1" - -[[deps.GPUArraysCore]] -deps = ["Adapt"] -git-tree-sha1 = "2d6ca471a6c7b536127afccfa7564b5b39227fe0" -uuid = "46192b85-c4d5-4398-a991-12ede77f4527" -version = "0.1.5" - -[[deps.GPUCompiler]] -deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Scratch", "TimerOutputs", "UUIDs"] -git-tree-sha1 = "72b2e3c2ba583d1a7aa35129e56cf92e07c083e3" -uuid = "61eb1bfa-7361-4325-ad38-22787b887f55" -version = "0.21.4" - -[[deps.GeoInterface]] -deps = ["Extents"] -git-tree-sha1 = "bb198ff907228523f3dee1070ceee63b9359b6ab" -uuid = "cf35fbd7-0cd7-5166-be24-54bfbe79505f" -version = "1.3.1" - -[[deps.GeometryBasics]] -deps = ["EarCut_jll", "Extents", "GeoInterface", "IterTools", "LinearAlgebra", "StaticArrays", "StructArrays", "Tables"] -git-tree-sha1 = "424a5a6ce7c5d97cca7bcc4eac551b97294c54af" -uuid = "5c1252a2-5f33-56bf-86c9-59e7332b4326" -version = "0.4.9" - -[[deps.Gettext_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "XML2_jll"] -git-tree-sha1 = "9b02998aba7bf074d14de89f9d37ca24a1a0b046" -uuid = "78b55507-aeef-58d4-861c-77aaff3498b1" -version = "0.21.0+0" - -[[deps.Glib_jll]] -deps = ["Artifacts", "Gettext_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Libiconv_jll", "Libmount_jll", "PCRE2_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "d3b3624125c1474292d0d8ed0f65554ac37ddb23" -uuid = "7746bdde-850d-59dc-9ae8-88ece973131d" -version = "2.74.0+2" - -[[deps.Graphics]] -deps = ["Colors", "LinearAlgebra", "NaNMath"] -git-tree-sha1 = "d61890399bc535850c4bf08e4e0d3a7ad0f21cbd" -uuid = "a2bd30eb-e257-5431-a919-1863eab51364" -version = "1.1.2" - -[[deps.Graphite2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "344bf40dcab1073aca04aa0df4fb092f920e4011" -uuid = "3b182d85-2403-5c21-9c21-1e1f0cc25472" -version = "1.3.14+0" - -[[deps.GridLayoutBase]] -deps = ["GeometryBasics", "InteractiveUtils", "Observables"] -git-tree-sha1 = "f57a64794b336d4990d90f80b147474b869b1bc4" -uuid = "3955a311-db13-416c-9275-1d80ed98e5e9" -version = "0.9.2" - -[[deps.Grisu]] -git-tree-sha1 = "53bb909d1151e57e2484c3d1b53e19552b887fb2" -uuid = "42e2da0e-8278-4e71-bc24-59509adca0fe" -version = "1.0.2" - -[[deps.HTTP]] -deps = ["Base64", "CodecZlib", "ConcurrentUtilities", "Dates", "ExceptionUnwrapping", "Logging", "LoggingExtras", "MbedTLS", "NetworkOptions", "OpenSSL", "Random", "SimpleBufferStream", "Sockets", "URIs", "UUIDs"] -git-tree-sha1 = "cb56ccdd481c0dd7f975ad2b3b62d9eda088f7e2" -uuid = "cd3eb016-35fb-5094-929b-558a96fad6f3" -version = "1.9.14" - -[[deps.HarfBuzz_jll]] -deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "Glib_jll", "Graphite2_jll", "JLLWrappers", "Libdl", "Libffi_jll", "Pkg"] -git-tree-sha1 = "129acf094d168394e80ee1dc4bc06ec835e510a3" -uuid = "2e76f6c2-a576-52d4-95c1-20adfe4de566" -version = "2.8.1+1" - -[[deps.HypergeometricFunctions]] -deps = ["DualNumbers", "LinearAlgebra", "OpenLibm_jll", "SpecialFunctions"] -git-tree-sha1 = "f218fe3736ddf977e0e772bc9a586b2383da2685" -uuid = "34004b35-14d8-5ef3-9330-4cdb6864b03a" -version = "0.3.23" - -[[deps.IOCapture]] -deps = ["Logging", "Random"] -git-tree-sha1 = "d75853a0bdbfb1ac815478bacd89cd27b550ace6" -uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89" -version = "0.2.3" - -[[deps.ImageAxes]] -deps = ["AxisArrays", "ImageBase", "ImageCore", "Reexport", "SimpleTraits"] -git-tree-sha1 = "2e4520d67b0cef90865b3ef727594d2a58e0e1f8" -uuid = "2803e5a7-5153-5ecf-9a86-9b4c37f5f5ac" -version = "0.6.11" - -[[deps.ImageBase]] -deps = ["ImageCore", "Reexport"] -git-tree-sha1 = "b51bb8cae22c66d0f6357e3bcb6363145ef20835" -uuid = "c817782e-172a-44cc-b673-b171935fbb9e" -version = "0.1.5" - -[[deps.ImageCore]] -deps = ["AbstractFFTs", "ColorVectorSpace", "Colors", "FixedPointNumbers", "Graphics", "MappedArrays", "MosaicViews", "OffsetArrays", "PaddedViews", "Reexport"] -git-tree-sha1 = "acf614720ef026d38400b3817614c45882d75500" -uuid = "a09fc81d-aa75-5fe9-8630-4744c3626534" -version = "0.9.4" - -[[deps.ImageIO]] -deps = ["FileIO", "IndirectArrays", "JpegTurbo", "LazyModules", "Netpbm", "OpenEXR", "PNGFiles", "QOI", "Sixel", "TiffImages", "UUIDs"] -git-tree-sha1 = "bca20b2f5d00c4fbc192c3212da8fa79f4688009" -uuid = "82e4d734-157c-48bb-816b-45c225c6df19" -version = "0.6.7" - -[[deps.ImageMetadata]] -deps = ["AxisArrays", "ImageAxes", "ImageBase", "ImageCore"] -git-tree-sha1 = "355e2b974f2e3212a75dfb60519de21361ad3cb7" -uuid = "bc367c6b-8a6b-528e-b4bd-a4b897500b49" -version = "0.9.9" - -[[deps.Imath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "3d09a9f60edf77f8a4d99f9e015e8fbf9989605d" -uuid = "905a6f67-0a94-5f89-b386-d35d92009cd1" -version = "3.1.7+0" - -[[deps.IndirectArrays]] -git-tree-sha1 = "012e604e1c7458645cb8b436f8fba789a51b257f" -uuid = "9b13fd28-a010-5f03-acff-a1bbcff69959" -version = "1.0.0" - -[[deps.Inflate]] -git-tree-sha1 = "5cd07aab533df5170988219191dfad0519391428" -uuid = "d25df0c9-e2be-5dd7-82c8-3ad0b3e990b9" -version = "0.1.3" - -[[deps.InlineStrings]] -deps = ["Parsers"] -git-tree-sha1 = "9cc2baf75c6d09f9da536ddf58eb2f29dedaf461" -uuid = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48" -version = "1.4.0" - -[[deps.IntegerMathUtils]] -git-tree-sha1 = "b8ffb903da9f7b8cf695a8bead8e01814aa24b30" -uuid = "18e54dd8-cb9d-406c-a71d-865a43cbb235" -version = "0.1.2" - -[[deps.IntelOpenMP_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ad37c091f7d7daf900963171600d7c1c5c3ede32" -uuid = "1d5cc7b8-4909-519e-a0f8-d0f5ad9712d0" -version = "2023.2.0+0" - -[[deps.InteractiveUtils]] -deps = ["Markdown"] -uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" - -[[deps.Interpolations]] -deps = ["Adapt", "AxisAlgorithms", "ChainRulesCore", "LinearAlgebra", "OffsetArrays", "Random", "Ratios", "Requires", "SharedArrays", "SparseArrays", "StaticArrays", "WoodburyMatrices"] -git-tree-sha1 = "721ec2cf720536ad005cb38f50dbba7b02419a15" -uuid = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59" -version = "0.14.7" - -[[deps.IntervalArithmetic]] -deps = ["CRlibm", "FastRounding", "LinearAlgebra", "Markdown", "Random", "RecipesBase", "RoundingEmulator", "SetRounding", "StaticArrays"] -git-tree-sha1 = "5ab7744289be503d76a944784bac3f2df7b809af" -uuid = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253" -version = "0.20.9" - -[[deps.IntervalSets]] -deps = ["Dates", "Random"] -git-tree-sha1 = "8e59ea773deee525c99a8018409f64f19fb719e6" -uuid = "8197267c-284f-5f27-9208-e0e47529a953" -version = "0.7.7" -weakdeps = ["Statistics"] - - [deps.IntervalSets.extensions] - IntervalSetsStatisticsExt = "Statistics" - -[[deps.InvertedIndices]] -git-tree-sha1 = "0dc7b50b8d436461be01300fd8cd45aa0274b038" -uuid = "41ab1584-1d38-5bbf-9106-f11c6c58b48f" -version = "1.3.0" - -[[deps.IrrationalConstants]] -git-tree-sha1 = "630b497eafcc20001bba38a4651b327dcfc491d2" -uuid = "92d709cd-6900-40b7-9082-c6be49f344b6" -version = "0.2.2" - -[[deps.Isoband]] -deps = ["isoband_jll"] -git-tree-sha1 = "f9b6d97355599074dc867318950adaa6f9946137" -uuid = "f1662d9f-8043-43de-a69a-05efc1cc6ff4" -version = "0.1.1" - -[[deps.IterTools]] -git-tree-sha1 = "4ced6667f9974fc5c5943fa5e2ef1ca43ea9e450" -uuid = "c8e1da08-722c-5040-9ed9-7db0dc04731e" -version = "1.8.0" - -[[deps.IterationControl]] -deps = ["EarlyStopping", "InteractiveUtils"] -git-tree-sha1 = "d7df9a6fdd82a8cfdfe93a94fcce35515be634da" -uuid = "b3c1a2ee-3fec-4384-bf48-272ea71de57c" -version = "0.5.3" - -[[deps.IterativeSolvers]] -deps = ["LinearAlgebra", "Printf", "Random", "RecipesBase", "SparseArrays"] -git-tree-sha1 = "1169632f425f79429f245113b775a0e3d121457c" -uuid = "42fd0dbc-a981-5370-80f2-aaf504508153" -version = "0.9.2" - -[[deps.IteratorInterfaceExtensions]] -git-tree-sha1 = "a3f24677c21f5bbe9d2a714f95dcd58337fb2856" -uuid = "82899510-4779-5014-852e-03e436cf321d" -version = "1.0.0" - -[[deps.JLLWrappers]] -deps = ["Preferences"] -git-tree-sha1 = "abc9885a7ca2052a736a600f7fa66209f96506e1" -uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210" -version = "1.4.1" - -[[deps.JSON]] -deps = ["Dates", "Mmap", "Parsers", "Unicode"] -git-tree-sha1 = "31e996f0a15c7b280ba9f76636b3ff9e2ae58c9a" -uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" -version = "0.21.4" - -[[deps.JpegTurbo]] -deps = ["CEnum", "FileIO", "ImageCore", "JpegTurbo_jll", "TOML"] -git-tree-sha1 = "327713faef2a3e5c80f96bf38d1fa26f7a6ae29e" -uuid = "b835a17e-a41a-41e7-81f0-2f016b05efe0" -version = "0.1.3" - -[[deps.JpegTurbo_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "6f2675ef130a300a112286de91973805fcc5ffbc" -uuid = "aacddb02-875f-59d6-b918-886e6ef4fbf8" -version = "2.1.91+0" - -[[deps.KernelAbstractions]] -deps = ["Adapt", "Atomix", "InteractiveUtils", "LinearAlgebra", "MacroTools", "PrecompileTools", "Requires", "SparseArrays", "StaticArrays", "UUIDs", "UnsafeAtomics", "UnsafeAtomicsLLVM"] -git-tree-sha1 = "4c5875e4c228247e1c2b087669846941fb6e0118" -uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -version = "0.9.8" - - [deps.KernelAbstractions.extensions] - EnzymeExt = "EnzymeCore" - - [deps.KernelAbstractions.weakdeps] - EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" - -[[deps.KernelDensity]] -deps = ["Distributions", "DocStringExtensions", "FFTW", "Interpolations", "StatsBase"] -git-tree-sha1 = "90442c50e202a5cdf21a7899c66b240fdef14035" -uuid = "5ab0869b-81aa-558d-bb23-cbf5423bbe9b" -version = "0.6.7" - -[[deps.LAME_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f6250b16881adf048549549fba48b1161acdac8c" -uuid = "c1c5ebd0-6772-5130-a774-d5fcae4a789d" -version = "3.100.1+0" - -[[deps.LLVM]] -deps = ["CEnum", "LLVMExtra_jll", "Libdl", "Printf", "Unicode"] -git-tree-sha1 = "8695a49bfe05a2dc0feeefd06b4ca6361a018729" -uuid = "929cbde3-209d-540e-8aea-75f648917ca0" -version = "6.1.0" - -[[deps.LLVMExtra_jll]] -deps = ["Artifacts", "JLLWrappers", "LazyArtifacts", "Libdl", "TOML"] -git-tree-sha1 = "c35203c1e1002747da220ffc3c0762ce7754b08c" -uuid = "dad2f222-ce93-54a1-a47d-0025e8a3acab" -version = "0.0.23+0" - -[[deps.LLVMOpenMP_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "f689897ccbe049adb19a065c495e75f372ecd42b" -uuid = "1d63c593-3942-5779-bab2-d838dc0a180e" -version = "15.0.4+0" - -[[deps.LZO_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "e5b909bcf985c5e2605737d2ce278ed791b89be6" -uuid = "dd4b983a-f0e5-5f8d-a1b7-129d4a5fb1ac" -version = "2.10.1+0" - -[[deps.LaTeXStrings]] -git-tree-sha1 = "f2355693d6778a178ade15952b7ac47a4ff97996" -uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" -version = "1.3.0" - -[[deps.LatinHypercubeSampling]] -deps = ["Random", "StableRNGs", "StatsBase", "Test"] -git-tree-sha1 = "825289d43c753c7f1bf9bed334c253e9913997f8" -uuid = "a5e1c1ea-c99a-51d3-a14d-a9a37257b02d" -version = "1.9.0" - -[[deps.LazyArtifacts]] -deps = ["Artifacts", "Pkg"] -uuid = "4af54fe1-eca0-43a8-85a7-787d91b784e3" - -[[deps.LazyModules]] -git-tree-sha1 = "a560dd966b386ac9ae60bdd3a3d3a326062d3c3e" -uuid = "8cdb02fc-e678-4876-92c5-9defec4f444e" -version = "0.3.1" - -[[deps.LibCURL]] -deps = ["LibCURL_jll", "MozillaCACerts_jll"] -uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" -version = "0.6.3" - -[[deps.LibCURL_jll]] -deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] -uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" -version = "7.84.0+0" - -[[deps.LibGit2]] -deps = ["Base64", "NetworkOptions", "Printf", "SHA"] -uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" - -[[deps.LibSSH2_jll]] -deps = ["Artifacts", "Libdl", "MbedTLS_jll"] -uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" -version = "1.10.2+0" - -[[deps.Libdl]] -uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" - -[[deps.Libffi_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "0b4a5d71f3e5200a7dff793393e09dfc2d874290" -uuid = "e9f186c6-92d2-5b65-8a66-fee21dc1b490" -version = "3.2.2+1" - -[[deps.Libgcrypt_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgpg_error_jll", "Pkg"] -git-tree-sha1 = "64613c82a59c120435c067c2b809fc61cf5166ae" -uuid = "d4300ac3-e22c-5743-9152-c294e39db1e4" -version = "1.8.7+0" - -[[deps.Libgpg_error_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "c333716e46366857753e273ce6a69ee0945a6db9" -uuid = "7add5ba3-2f88-524e-9cd5-f83b8a55f7b8" -version = "1.42.0+0" - -[[deps.Libiconv_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "c7cb1f5d892775ba13767a87c7ada0b980ea0a71" -uuid = "94ce4f54-9a6c-5748-9c1c-f9c7231a4531" -version = "1.16.1+2" - -[[deps.Libmount_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "9c30530bf0effd46e15e0fdcf2b8636e78cbbd73" -uuid = "4b2f31a3-9ecc-558c-b454-b3730dcb73e9" -version = "2.35.0+0" - -[[deps.Libuuid_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "7f3efec06033682db852f8b3bc3c1d2b0a0ab066" -uuid = "38a345b3-de98-5d2b-a5d3-14cd9215e700" -version = "2.36.0+0" - -[[deps.LightXML]] -deps = ["Libdl", "XML2_jll"] -git-tree-sha1 = "e129d9391168c677cd4800f5c0abb1ed8cb3794f" -uuid = "9c8b4983-aa76-5018-a973-4c85ecc9e179" -version = "0.9.0" - -[[deps.LineSearches]] -deps = ["LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "Printf"] -git-tree-sha1 = "7bbea35cec17305fc70a0e5b4641477dc0789d9d" -uuid = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" -version = "7.2.0" - -[[deps.LinearAlgebra]] -deps = ["Libdl", "OpenBLAS_jll", "libblastrampoline_jll"] -uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" - -[[deps.LinearAlgebraX]] -deps = ["LinearAlgebra", "Mods", "Permutations", "Primes", "SimplePolynomials"] -git-tree-sha1 = "558a338f1eeabe933f9c2d4052aa7c2c707c3d52" -uuid = "9b3f67b0-2d00-526e-9884-9e4938f8fb88" -version = "0.1.12" - -[[deps.LinearMaps]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "6698ab5e662b47ffc63a82b2f43c1cee015cf80d" -uuid = "7a12625a-238d-50fd-b39a-03d52299707e" -version = "3.11.0" -weakdeps = ["ChainRulesCore", "SparseArrays", "Statistics"] - - [deps.LinearMaps.extensions] - LinearMapsChainRulesCoreExt = "ChainRulesCore" - LinearMapsSparseArraysExt = "SparseArrays" - LinearMapsStatisticsExt = "Statistics" - -[[deps.Literate]] -deps = ["Base64", "IOCapture", "JSON", "REPL"] -git-tree-sha1 = "a1a0d4ff9f785a2184baca7a5c89e664f144143d" -uuid = "98b081ad-f1c9-55d3-8b20-4c87d4299306" -version = "2.14.1" - -[[deps.LogExpFunctions]] -deps = ["DocStringExtensions", "IrrationalConstants", "LinearAlgebra"] -git-tree-sha1 = "c3ce8e7420b3a6e071e0fe4745f5d4300e37b13f" -uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" -version = "0.3.24" - - [deps.LogExpFunctions.extensions] - LogExpFunctionsChainRulesCoreExt = "ChainRulesCore" - LogExpFunctionsChangesOfVariablesExt = "ChangesOfVariables" - LogExpFunctionsInverseFunctionsExt = "InverseFunctions" - - [deps.LogExpFunctions.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.Logging]] -uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" - -[[deps.LoggingExtras]] -deps = ["Dates", "Logging"] -git-tree-sha1 = "cedb76b37bc5a6c702ade66be44f831fa23c681e" -uuid = "e6f89c97-d47a-5376-807f-9c37f3926c36" -version = "1.0.0" - -[[deps.LossFunctions]] -deps = ["Markdown", "Requires", "Statistics"] -git-tree-sha1 = "c2b72b61d2e3489b1f9cae3403226b21ec90c943" -uuid = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" -version = "0.11.0" -weakdeps = ["CategoricalArrays"] - - [deps.LossFunctions.extensions] - LossFunctionsCategoricalArraysExt = "CategoricalArrays" - -[[deps.MKL_jll]] -deps = ["Artifacts", "IntelOpenMP_jll", "JLLWrappers", "LazyArtifacts", "Libdl", "Pkg"] -git-tree-sha1 = "eb006abbd7041c28e0d16260e50a24f8f9104913" -uuid = "856f044c-d86e-5d09-b602-aeab76dc8ba7" -version = "2023.2.0+0" - -[[deps.MLJ]] -deps = ["CategoricalArrays", "ComputationalResources", "Distributed", "Distributions", "LinearAlgebra", "MLJBase", "MLJEnsembles", "MLJIteration", "MLJModels", "MLJTuning", "OpenML", "Pkg", "ProgressMeter", "Random", "ScientificTypes", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "d26cd777c711c332019b39445823cbb1f6cdb7e5" -uuid = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" -version = "0.19.2" - -[[deps.MLJBase]] -deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Dates", "DelimitedFiles", "Distributed", "Distributions", "InteractiveUtils", "InvertedIndices", "LinearAlgebra", "LossFunctions", "MLJModelInterface", "Missings", "OrderedCollections", "Parameters", "PrettyTables", "ProgressMeter", "Random", "ScientificTypes", "Serialization", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "2c9d6b9c627a80f6e6acbc6193026f455581fd04" -uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -version = "0.21.13" - -[[deps.MLJEnsembles]] -deps = ["CategoricalArrays", "CategoricalDistributions", "ComputationalResources", "Distributed", "Distributions", "MLJBase", "MLJModelInterface", "ProgressMeter", "Random", "ScientificTypesBase", "StatsBase"] -git-tree-sha1 = "95b306ef8108067d26dfde9ff3457d59911cc0d6" -uuid = "50ed68f4-41fd-4504-931a-ed422449fee0" -version = "0.3.3" - -[[deps.MLJGLMInterface]] -deps = ["Distributions", "GLM", "MLJModelInterface", "StatsModels", "Tables"] -git-tree-sha1 = "06aba1c96b19f31744f7e97d96fcf66b79739e05" -uuid = "caf8df21-4939-456d-ac9c-5fefbfb04c0c" -version = "0.3.5" - -[[deps.MLJIteration]] -deps = ["IterationControl", "MLJBase", "Random", "Serialization"] -git-tree-sha1 = "be6d5c71ab499a59e82d65e00a89ceba8732fcd5" -uuid = "614be32b-d00c-4edb-bd02-1eb411ab5e55" -version = "0.5.1" - -[[deps.MLJLinearModels]] -deps = ["DocStringExtensions", "IterativeSolvers", "LinearAlgebra", "LinearMaps", "MLJModelInterface", "Optim", "Parameters"] -git-tree-sha1 = "c92bf0ea37bf51e1ef0160069c572825819748b8" -uuid = "6ee0df7b-362f-4a72-a706-9e79364fb692" -version = "0.9.2" - -[[deps.MLJModelInterface]] -deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] -git-tree-sha1 = "c8b7e632d6754a5e36c0d94a4b466a5ba3a30128" -uuid = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -version = "1.8.0" - -[[deps.MLJModels]] -deps = ["CategoricalArrays", "CategoricalDistributions", "Combinatorics", "Dates", "Distances", "Distributions", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "Markdown", "OrderedCollections", "Parameters", "Pkg", "PrettyPrinting", "REPL", "Random", "RelocatableFolders", "ScientificTypes", "StatisticalTraits", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "2b49f04f70266a2b040eb46ece157c4f5c1b0c13" -uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7" -version = "0.16.10" - -[[deps.MLJTuning]] -deps = ["ComputationalResources", "Distributed", "Distributions", "LatinHypercubeSampling", "MLJBase", "ProgressMeter", "Random", "RecipesBase"] -git-tree-sha1 = "02688098bd77827b64ed8ad747c14f715f98cfc4" -uuid = "03970b2e-30c4-11ea-3135-d1576263f10f" -version = "0.7.4" - -[[deps.MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "42324d08725e200c23d4dfb549e0d5d89dede2d2" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.10" - -[[deps.Makie]] -deps = ["Animations", "Base64", "ColorBrewer", "ColorSchemes", "ColorTypes", "Colors", "Contour", "DelaunayTriangulation", "Distributions", "DocStringExtensions", "Downloads", "FFMPEG", "FileIO", "FixedPointNumbers", "Formatting", "FreeType", "FreeTypeAbstraction", "GeometryBasics", "GridLayoutBase", "ImageIO", "InteractiveUtils", "IntervalSets", "Isoband", "KernelDensity", "LaTeXStrings", "LinearAlgebra", "MacroTools", "MakieCore", "Markdown", "Match", "MathTeXEngine", "Observables", "OffsetArrays", "Packing", "PlotUtils", "PolygonOps", "PrecompileTools", "Printf", "REPL", "Random", "RelocatableFolders", "Setfield", "ShaderAbstractions", "Showoff", "SignedDistanceFields", "SparseArrays", "StableHashTraits", "Statistics", "StatsBase", "StatsFuns", "StructArrays", "TriplotBase", "UnicodeFun"] -git-tree-sha1 = "729640354756782c89adba8857085a69e19be7ab" -uuid = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" -version = "0.19.7" - -[[deps.MakieCore]] -deps = ["Observables"] -git-tree-sha1 = "87a85ff81583bd392642869557cb633532989517" -uuid = "20f20a25-4f0e-4fdf-b5d1-57303727442b" -version = "0.6.4" - -[[deps.MappedArrays]] -git-tree-sha1 = "2dab0221fe2b0f2cb6754eaa743cc266339f527e" -uuid = "dbb5928d-eab1-5f90-85c2-b9b0edb7c900" -version = "0.4.2" - -[[deps.Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[deps.Match]] -git-tree-sha1 = "1d9bc5c1a6e7ee24effb93f175c9342f9154d97f" -uuid = "7eb4fadd-790c-5f42-8a69-bfa0b872bfbf" -version = "1.2.0" - -[[deps.MathTeXEngine]] -deps = ["AbstractTrees", "Automa", "DataStructures", "FreeTypeAbstraction", "GeometryBasics", "LaTeXStrings", "REPL", "RelocatableFolders", "Test", "UnicodeFun"] -git-tree-sha1 = "8f52dbaa1351ce4cb847d95568cb29e62a307d93" -uuid = "0a4f8689-d25c-4efe-a92b-7142dfc1aa53" -version = "0.5.6" - -[[deps.MbedTLS]] -deps = ["Dates", "MbedTLS_jll", "MozillaCACerts_jll", "Random", "Sockets"] -git-tree-sha1 = "03a9b9718f5682ecb107ac9f7308991db4ce395b" -uuid = "739be429-bea8-5141-9913-cc70e7f3736d" -version = "1.1.7" - -[[deps.MbedTLS_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" -version = "2.28.2+0" - -[[deps.Missings]] -deps = ["DataAPI"] -git-tree-sha1 = "f66bdc5de519e8f8ae43bdc598782d35a25b1272" -uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "1.1.0" - -[[deps.Mmap]] -uuid = "a63ad114-7e13-5084-954f-fe012c677804" - -[[deps.Mods]] -git-tree-sha1 = "61be59e4daffff43a8cec04b5e0dc773cbb5db3a" -uuid = "7475f97c-0381-53b1-977b-4c60186c8d62" -version = "1.3.3" - -[[deps.MosaicViews]] -deps = ["MappedArrays", "OffsetArrays", "PaddedViews", "StackViews"] -git-tree-sha1 = "7b86a5d4d70a9f5cdf2dacb3cbe6d251d1a61dbe" -uuid = "e94cdb99-869f-56ef-bcf0-1ae2bcbe0389" -version = "0.3.4" - -[[deps.MozillaCACerts_jll]] -uuid = "14a3606d-f60d-562e-9121-12d972cd8159" -version = "2022.10.11" - -[[deps.Multisets]] -git-tree-sha1 = "8d852646862c96e226367ad10c8af56099b4047e" -uuid = "3b2b4ff1-bcff-5658-a3ee-dbcf1ce5ac09" -version = "0.4.4" - -[[deps.NLSolversBase]] -deps = ["DiffResults", "Distributed", "FiniteDiff", "ForwardDiff"] -git-tree-sha1 = "a0b464d183da839699f4c79e7606d9d186ec172c" -uuid = "d41bc354-129a-5804-8e4c-c37616107c6c" -version = "7.8.3" - -[[deps.NaNMath]] -deps = ["OpenLibm_jll"] -git-tree-sha1 = "0877504529a3e5c3343c6f8b4c0381e57e4387e4" -uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" -version = "1.0.2" - -[[deps.NearestNeighborModels]] -deps = ["Distances", "FillArrays", "InteractiveUtils", "LinearAlgebra", "MLJModelInterface", "NearestNeighbors", "Statistics", "StatsBase", "Tables"] -git-tree-sha1 = "e411143a8362926e4284a54e745972e939fbab78" -uuid = "636a865e-7cf4-491e-846c-de09b730eb36" -version = "0.2.3" - -[[deps.NearestNeighbors]] -deps = ["Distances", "StaticArrays"] -git-tree-sha1 = "2c3726ceb3388917602169bed973dbc97f1b51a8" -uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce" -version = "0.4.13" - -[[deps.Netpbm]] -deps = ["FileIO", "ImageCore", "ImageMetadata"] -git-tree-sha1 = "d92b107dbb887293622df7697a2223f9f8176fcd" -uuid = "f09324ee-3d7c-5217-9330-fc30815ba969" -version = "1.1.1" - -[[deps.NetworkLayout]] -deps = ["GeometryBasics", "LinearAlgebra", "Random", "Requires", "SparseArrays", "StaticArrays"] -git-tree-sha1 = "2bfd8cd7fba3e46ce48139ae93904ee848153660" -uuid = "46757867-2c16-5918-afeb-47bfcb05e46a" -version = "0.4.5" - -[[deps.NetworkOptions]] -uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" -version = "1.2.0" - -[[deps.Observables]] -git-tree-sha1 = "6862738f9796b3edc1c09d0890afce4eca9e7e93" -uuid = "510215fc-4207-5dde-b226-833fc4488ee2" -version = "0.5.4" - -[[deps.OffsetArrays]] -deps = ["Adapt"] -git-tree-sha1 = "2ac17d29c523ce1cd38e27785a7d23024853a4bb" -uuid = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" -version = "1.12.10" - -[[deps.Ogg_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "887579a3eb005446d514ab7aeac5d1d027658b8f" -uuid = "e7412a2a-1a6e-54c0-be00-318e2571c051" -version = "1.3.5+1" - -[[deps.OpenBLAS_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] -uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.21+4" - -[[deps.OpenEXR]] -deps = ["Colors", "FileIO", "OpenEXR_jll"] -git-tree-sha1 = "327f53360fdb54df7ecd01e96ef1983536d1e633" -uuid = "52e1d378-f018-4a11-a4be-720524705ac7" -version = "0.3.2" - -[[deps.OpenEXR_jll]] -deps = ["Artifacts", "Imath_jll", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "a4ca623df1ae99d09bc9868b008262d0c0ac1e4f" -uuid = "18a262bb-aa17-5467-a713-aee519bc75cb" -version = "3.1.4+0" - -[[deps.OpenLibm_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "05823500-19ac-5b8b-9628-191a04bc5112" -version = "0.8.1+0" - -[[deps.OpenML]] -deps = ["ARFFFiles", "HTTP", "JSON", "Markdown", "Pkg", "Scratch"] -git-tree-sha1 = "6efb039ae888699d5a74fb593f6f3e10c7193e33" -uuid = "8b6db2d4-7670-4922-a472-f9537c81ab66" -version = "0.3.1" - -[[deps.OpenSSL]] -deps = ["BitFlags", "Dates", "MozillaCACerts_jll", "OpenSSL_jll", "Sockets"] -git-tree-sha1 = "51901a49222b09e3743c65b8847687ae5fc78eb2" -uuid = "4d8831e6-92b7-49fb-bdf8-b643e874388c" -version = "1.4.1" - -[[deps.OpenSSL_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e78db7bd5c26fc5a6911b50a47ee302219157ea8" -uuid = "458c3c95-2e84-50aa-8efc-19380b2a3a95" -version = "3.0.10+0" - -[[deps.OpenSpecFun_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1" -uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" -version = "0.5.5+0" - -[[deps.Optim]] -deps = ["Compat", "FillArrays", "ForwardDiff", "LineSearches", "LinearAlgebra", "NLSolversBase", "NaNMath", "Parameters", "PositiveFactorizations", "Printf", "SparseArrays", "StatsBase"] -git-tree-sha1 = "e3a6546c1577bfd701771b477b794a52949e7594" -uuid = "429524aa-4258-5aef-a3af-852621145aeb" -version = "1.7.6" - -[[deps.Opus_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "51a08fb14ec28da2ec7a927c4337e4332c2a4720" -uuid = "91d4177d-7536-5919-b921-800302f37372" -version = "1.3.2+0" - -[[deps.OrderedCollections]] -git-tree-sha1 = "2e73fe17cac3c62ad1aebe70d44c963c3cfdc3e3" -uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.6.2" - -[[deps.PCRE2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "efcefdf7-47ab-520b-bdef-62a2eaa19f15" -version = "10.42.0+0" - -[[deps.PDMats]] -deps = ["LinearAlgebra", "SparseArrays", "SuiteSparse"] -git-tree-sha1 = "67eae2738d63117a196f497d7db789821bce61d1" -uuid = "90014a1f-27ba-587c-ab20-58faa44d9150" -version = "0.11.17" - -[[deps.PNGFiles]] -deps = ["Base64", "CEnum", "ImageCore", "IndirectArrays", "OffsetArrays", "libpng_jll"] -git-tree-sha1 = "9b02b27ac477cad98114584ff964e3052f656a0f" -uuid = "f57f5aa1-a3ce-4bc8-8ab9-96f992907883" -version = "0.4.0" - -[[deps.Packing]] -deps = ["GeometryBasics"] -git-tree-sha1 = "ec3edfe723df33528e085e632414499f26650501" -uuid = "19eb6ba3-879d-56ad-ad62-d5c202156566" -version = "0.5.0" - -[[deps.PaddedViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "0fac6313486baae819364c52b4f483450a9d793f" -uuid = "5432bcbf-9aad-5242-b902-cca2824c8663" -version = "0.5.12" - -[[deps.Pango_jll]] -deps = ["Artifacts", "Cairo_jll", "Fontconfig_jll", "FreeType2_jll", "FriBidi_jll", "Glib_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "84a314e3926ba9ec66ac097e3635e270986b0f10" -uuid = "36c8627f-9965-5494-a995-c6b170f724f3" -version = "1.50.9+0" - -[[deps.Parameters]] -deps = ["OrderedCollections", "UnPack"] -git-tree-sha1 = "34c0e9ad262e5f7fc75b10a9952ca7692cfc5fbe" -uuid = "d96e819e-fc66-5662-9728-84c9c7592b0a" -version = "0.12.3" - -[[deps.Parsers]] -deps = ["Dates", "PrecompileTools", "UUIDs"] -git-tree-sha1 = "716e24b21538abc91f6205fd1d8363f39b442851" -uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "2.7.2" - -[[deps.Permutations]] -deps = ["Combinatorics", "LinearAlgebra", "Random"] -git-tree-sha1 = "6e6cab1c54ae2382bcc48866b91cf949cea703a1" -uuid = "2ae35dd2-176d-5d53-8349-f30d82d94d4f" -version = "0.4.16" - -[[deps.Pixman_jll]] -deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "LLVMOpenMP_jll", "Libdl"] -git-tree-sha1 = "64779bc4c9784fee475689a1752ef4d5747c5e87" -uuid = "30392449-352a-5448-841d-b1acce4e97dc" -version = "0.42.2+0" - -[[deps.Pkg]] -deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] -uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -version = "1.9.2" - -[[deps.PkgVersion]] -deps = ["Pkg"] -git-tree-sha1 = "f6cf8e7944e50901594838951729a1861e668cb8" -uuid = "eebad327-c553-4316-9ea0-9fa01ccd7688" -version = "0.3.2" - -[[deps.PlotUtils]] -deps = ["ColorSchemes", "Colors", "Dates", "PrecompileTools", "Printf", "Random", "Reexport", "Statistics"] -git-tree-sha1 = "f92e1315dadf8c46561fb9396e525f7200cdc227" -uuid = "995b91a9-d308-5afd-9ec6-746e21dbc043" -version = "1.3.5" - -[[deps.PolygonOps]] -git-tree-sha1 = "77b3d3605fc1cd0b42d95eba87dfcd2bf67d5ff6" -uuid = "647866c9-e3ac-4575-94e7-e3d426903924" -version = "0.1.2" - -[[deps.Polynomials]] -deps = ["LinearAlgebra", "RecipesBase"] -git-tree-sha1 = "3aa2bb4982e575acd7583f01531f241af077b163" -uuid = "f27b6e38-b328-58d1-80ce-0feddd5e7a45" -version = "3.2.13" - - [deps.Polynomials.extensions] - PolynomialsChainRulesCoreExt = "ChainRulesCore" - PolynomialsMakieCoreExt = "MakieCore" - PolynomialsMutableArithmeticsExt = "MutableArithmetics" - - [deps.Polynomials.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - MakieCore = "20f20a25-4f0e-4fdf-b5d1-57303727442b" - MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" - -[[deps.PooledArrays]] -deps = ["DataAPI", "Future"] -git-tree-sha1 = "a6062fe4063cdafe78f4a0a81cfffb89721b30e7" -uuid = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" -version = "1.4.2" - -[[deps.PositiveFactorizations]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "17275485f373e6673f7e7f97051f703ed5b15b20" -uuid = "85a6dd25-e78a-55b7-8502-1745935b8125" -version = "0.2.4" - -[[deps.PrecompileTools]] -deps = ["Preferences"] -git-tree-sha1 = "9673d39decc5feece56ef3940e5dafba15ba0f81" -uuid = "aea7be01-6a6a-4083-8856-8a6e6704d82a" -version = "1.1.2" - -[[deps.Preferences]] -deps = ["TOML"] -git-tree-sha1 = "7eb1686b4f04b82f96ed7a4ea5890a4f0c7a09f1" -uuid = "21216c6a-2e73-6563-6e65-726566657250" -version = "1.4.0" - -[[deps.PrettyPrinting]] -git-tree-sha1 = "22a601b04a154ca38867b991d5017469dc75f2db" -uuid = "54e16d92-306c-5ea0-a30b-337be88ac337" -version = "0.4.1" - -[[deps.PrettyTables]] -deps = ["Crayons", "LaTeXStrings", "Markdown", "Printf", "Reexport", "StringManipulation", "Tables"] -git-tree-sha1 = "ee094908d720185ddbdc58dbe0c1cbe35453ec7a" -uuid = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" -version = "2.2.7" - -[[deps.Primes]] -deps = ["IntegerMathUtils"] -git-tree-sha1 = "4c9f306e5d6603ae203c2000dd460d81a5251489" -uuid = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" -version = "0.5.4" - -[[deps.Printf]] -deps = ["Unicode"] -uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" - -[[deps.ProgressMeter]] -deps = ["Distributed", "Printf"] -git-tree-sha1 = "d7a7aef8f8f2d537104f170139553b14dfe39fe9" -uuid = "92933f4c-e287-5a05-a399-4b506db050ca" -version = "1.7.2" - -[[deps.QOI]] -deps = ["ColorTypes", "FileIO", "FixedPointNumbers"] -git-tree-sha1 = "18e8f4d1426e965c7b532ddd260599e1510d26ce" -uuid = "4b34888f-f399-49d4-9bb3-47ed5cae4e65" -version = "1.0.0" - -[[deps.QuadGK]] -deps = ["DataStructures", "LinearAlgebra"] -git-tree-sha1 = "6ec7ac8412e83d57e313393220879ede1740f9ee" -uuid = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" -version = "2.8.2" - -[[deps.REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] -uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" - -[[deps.Random]] -deps = ["SHA", "Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[deps.Random123]] -deps = ["Random", "RandomNumbers"] -git-tree-sha1 = "552f30e847641591ba3f39fd1bed559b9deb0ef3" -uuid = "74087812-796a-5b5d-8853-05524746bad3" -version = "1.6.1" - -[[deps.RandomNumbers]] -deps = ["Random", "Requires"] -git-tree-sha1 = "043da614cc7e95c703498a491e2c21f58a2b8111" -uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143" -version = "1.5.3" - -[[deps.RangeArrays]] -git-tree-sha1 = "b9039e93773ddcfc828f12aadf7115b4b4d225f5" -uuid = "b3c3ace0-ae52-54e7-9d0b-2c1406fd6b9d" -version = "0.3.2" - -[[deps.Ratios]] -deps = ["Requires"] -git-tree-sha1 = "1342a47bf3260ee108163042310d26f2be5ec90b" -uuid = "c84ed2f1-dad5-54f0-aa8e-dbefe2724439" -version = "0.4.5" -weakdeps = ["FixedPointNumbers"] - - [deps.Ratios.extensions] - RatiosFixedPointNumbersExt = "FixedPointNumbers" - -[[deps.RecipesBase]] -deps = ["PrecompileTools"] -git-tree-sha1 = "5c3d09cc4f31f5fc6af001c250bf1278733100ff" -uuid = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" -version = "1.3.4" - -[[deps.Reexport]] -git-tree-sha1 = "45e428421666073eab6f2da5c9d310d99bb12f9b" -uuid = "189a3867-3050-52da-a836-e630ba90ab69" -version = "1.2.2" - -[[deps.RelocatableFolders]] -deps = ["SHA", "Scratch"] -git-tree-sha1 = "90bc7a7c96410424509e4263e277e43250c05691" -uuid = "05181044-ff0b-4ac5-8273-598c1e38db00" -version = "1.0.0" - -[[deps.Requires]] -deps = ["UUIDs"] -git-tree-sha1 = "838a3a4188e2ded87a4f9f184b4b0d78a1e91cb7" -uuid = "ae029012-a4dd-5104-9daa-d747884805df" -version = "1.3.0" - -[[deps.RingLists]] -deps = ["Random"] -git-tree-sha1 = "9712ebc42e91850f35272b48eb840e60c0270ec0" -uuid = "286e9d63-9694-5540-9e3c-4e6708fa07b2" -version = "0.2.7" - -[[deps.Rmath]] -deps = ["Random", "Rmath_jll"] -git-tree-sha1 = "f65dcb5fa46aee0cf9ed6274ccbd597adc49aa7b" -uuid = "79098fc4-a85e-5d69-aa6a-4863f24498fa" -version = "0.7.1" - -[[deps.Rmath_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "6ed52fdd3382cf21947b15e8870ac0ddbff736da" -uuid = "f50d1b31-88e8-58de-be2c-1cc44531875f" -version = "0.4.0+0" - -[[deps.RoundingEmulator]] -git-tree-sha1 = "40b9edad2e5287e05bd413a38f61a8ff55b9557b" -uuid = "5eaf0fd0-dfba-4ccb-bf02-d820a40db705" -version = "0.2.1" - -[[deps.SHA]] -uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" -version = "0.7.0" - -[[deps.SQLite]] -deps = ["DBInterface", "Random", "SQLite_jll", "Serialization", "Tables", "WeakRefStrings"] -git-tree-sha1 = "eb9a473c9b191ced349d04efa612ec9f39c087ea" -uuid = "0aa819cd-b072-5ff4-a722-6bc24af294d9" -version = "1.6.0" - -[[deps.SQLite_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Zlib_jll"] -git-tree-sha1 = "4619dd3363610d94fb42a95a6dc35b526a26d0ef" -uuid = "76ed43ae-9a5d-5a62-8c75-30186b810ce8" -version = "3.42.0+0" - -[[deps.ScientificTypes]] -deps = ["CategoricalArrays", "ColorTypes", "Dates", "Distributions", "PrettyTables", "Reexport", "ScientificTypesBase", "StatisticalTraits", "Tables"] -git-tree-sha1 = "75ccd10ca65b939dab03b812994e571bf1e3e1da" -uuid = "321657f4-b219-11e9-178b-2701a2544e81" -version = "3.0.2" - -[[deps.ScientificTypesBase]] -git-tree-sha1 = "a8e18eb383b5ecf1b5e6fc237eb39255044fd92b" -uuid = "30f210dd-8aff-4c5f-94ba-8e64358c1161" -version = "3.0.0" - -[[deps.Scratch]] -deps = ["Dates"] -git-tree-sha1 = "30449ee12237627992a99d5e30ae63e4d78cd24a" -uuid = "6c6a2e73-6563-6170-7368-637461726353" -version = "1.2.0" - -[[deps.SentinelArrays]] -deps = ["Dates", "Random"] -git-tree-sha1 = "04bdff0b09c65ff3e06a05e3eb7b120223da3d39" -uuid = "91c51154-3ec4-41a3-a24f-3f23e20d615c" -version = "1.4.0" - -[[deps.Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" - -[[deps.SetRounding]] -git-tree-sha1 = "d7a25e439d07a17b7cdf97eecee504c50fedf5f6" -uuid = "3cc68bcd-71a2-5612-b932-767ffbe40ab0" -version = "0.2.1" - -[[deps.Setfield]] -deps = ["ConstructionBase", "Future", "MacroTools", "StaticArraysCore"] -git-tree-sha1 = "e2cc6d8c88613c05e1defb55170bf5ff211fbeac" -uuid = "efcf1570-3423-57d1-acb7-fd33fddbac46" -version = "1.1.1" - -[[deps.ShaderAbstractions]] -deps = ["ColorTypes", "FixedPointNumbers", "GeometryBasics", "LinearAlgebra", "Observables", "StaticArrays", "StructArrays", "Tables"] -git-tree-sha1 = "0d15c3e7b2003f4451714f08ffec2b77badc2dc4" -uuid = "65257c39-d410-5151-9873-9b3e5be5013e" -version = "0.3.0" - -[[deps.SharedArrays]] -deps = ["Distributed", "Mmap", "Random", "Serialization"] -uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383" - -[[deps.ShiftedArrays]] -git-tree-sha1 = "503688b59397b3307443af35cd953a13e8005c16" -uuid = "1277b4bf-5013-50f5-be3d-901d8477a67a" -version = "2.0.0" - -[[deps.Showoff]] -deps = ["Dates", "Grisu"] -git-tree-sha1 = "91eddf657aca81df9ae6ceb20b959ae5653ad1de" -uuid = "992d4aef-0814-514b-bc4d-f2e9a6c4116f" -version = "1.0.3" - -[[deps.SignedDistanceFields]] -deps = ["Random", "Statistics", "Test"] -git-tree-sha1 = "d263a08ec505853a5ff1c1ebde2070419e3f28e9" -uuid = "73760f76-fbc4-59ce-8f25-708e95d2df96" -version = "0.4.0" - -[[deps.SimpleBufferStream]] -git-tree-sha1 = "874e8867b33a00e784c8a7e4b60afe9e037b74e1" -uuid = "777ac1f9-54b0-4bf8-805c-2214025038e7" -version = "1.1.0" - -[[deps.SimpleGraphs]] -deps = ["AbstractLattices", "Combinatorics", "DataStructures", "IterTools", "LightXML", "LinearAlgebra", "LinearAlgebraX", "Optim", "Primes", "Random", "RingLists", "SimplePartitions", "SimplePolynomials", "SimpleRandom", "SparseArrays", "Statistics"] -git-tree-sha1 = "b608903049d11cc557c45e03b3a53e9260579c19" -uuid = "55797a34-41de-5266-9ec1-32ac4eb504d3" -version = "0.8.4" - -[[deps.SimplePartitions]] -deps = ["AbstractLattices", "DataStructures", "Permutations"] -git-tree-sha1 = "dcc02923a53f316ab97da8ef3136e80b4543dbf1" -uuid = "ec83eff0-a5b5-5643-ae32-5cbf6eedec9d" -version = "0.3.0" - -[[deps.SimplePolynomials]] -deps = ["Mods", "Multisets", "Polynomials", "Primes"] -git-tree-sha1 = "9f1b1f47279018b35316c62e829af1f3f6725a47" -uuid = "cc47b68c-3164-5771-a705-2bc0097375a0" -version = "0.2.13" - -[[deps.SimpleRandom]] -deps = ["Distributions", "LinearAlgebra", "Random"] -git-tree-sha1 = "3a6fb395e37afab81aeea85bae48a4db5cd7244a" -uuid = "a6525b86-64cd-54fa-8f65-62fc48bdc0e8" -version = "0.3.1" - -[[deps.SimpleTraits]] -deps = ["InteractiveUtils", "MacroTools"] -git-tree-sha1 = "5d7e3f4e11935503d3ecaf7186eac40602e7d231" -uuid = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" -version = "0.9.4" - -[[deps.Sixel]] -deps = ["Dates", "FileIO", "ImageCore", "IndirectArrays", "OffsetArrays", "REPL", "libsixel_jll"] -git-tree-sha1 = "2da10356e31327c7096832eb9cd86307a50b1eb6" -uuid = "45858cf5-a6b0-47a3-bbea-62219f50df47" -version = "0.1.3" - -[[deps.Sockets]] -uuid = "6462fe0b-24de-5631-8697-dd941f90decc" - -[[deps.SortingAlgorithms]] -deps = ["DataStructures"] -git-tree-sha1 = "c60ec5c62180f27efea3ba2908480f8055e17cee" -uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c" -version = "1.1.1" - -[[deps.SparseArrays]] -deps = ["Libdl", "LinearAlgebra", "Random", "Serialization", "SuiteSparse_jll"] -uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" - -[[deps.SpecialFunctions]] -deps = ["IrrationalConstants", "LogExpFunctions", "OpenLibm_jll", "OpenSpecFun_jll"] -git-tree-sha1 = "7beb031cf8145577fbccacd94b8a8f4ce78428d3" -uuid = "276daf66-3868-5448-9aa4-cd146d93841b" -version = "2.3.0" -weakdeps = ["ChainRulesCore"] - - [deps.SpecialFunctions.extensions] - SpecialFunctionsChainRulesCoreExt = "ChainRulesCore" - -[[deps.StableHashTraits]] -deps = ["CRC32c", "Compat", "Dates", "SHA", "Tables", "TupleTools", "UUIDs"] -git-tree-sha1 = "0b8b801b8f03a329a4e86b44c5e8a7d7f4fe10a3" -uuid = "c5dd0088-6c3f-4803-b00e-f31a60c170fa" -version = "0.3.1" - -[[deps.StableRNGs]] -deps = ["Random", "Test"] -git-tree-sha1 = "3be7d49667040add7ee151fefaf1f8c04c8c8276" -uuid = "860ef19b-820b-49d6-a774-d7a799459cd3" -version = "1.0.0" - -[[deps.StackViews]] -deps = ["OffsetArrays"] -git-tree-sha1 = "46e589465204cd0c08b4bd97385e4fa79a0c770c" -uuid = "cae243ae-269e-4f55-b966-ac2d0dc13c15" -version = "0.1.1" - -[[deps.StaticArrays]] -deps = ["LinearAlgebra", "Random", "StaticArraysCore"] -git-tree-sha1 = "9cabadf6e7cd2349b6cf49f1915ad2028d65e881" -uuid = "90137ffa-7385-5640-81b9-e52037218182" -version = "1.6.2" -weakdeps = ["Statistics"] - - [deps.StaticArrays.extensions] - StaticArraysStatisticsExt = "Statistics" - -[[deps.StaticArraysCore]] -git-tree-sha1 = "36b3d696ce6366023a0ea192b4cd442268995a0d" -uuid = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" -version = "1.4.2" - -[[deps.StatisticalTraits]] -deps = ["ScientificTypesBase"] -git-tree-sha1 = "30b9236691858e13f167ce829490a68e1a597782" -uuid = "64bff920-2084-43da-a3e6-9bb72801c0c9" -version = "3.2.0" - -[[deps.Statistics]] -deps = ["LinearAlgebra", "SparseArrays"] -uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -version = "1.9.0" - -[[deps.StatsAPI]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "45a7769a04a3cf80da1c1c7c60caf932e6f4c9f7" -uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0" -version = "1.6.0" - -[[deps.StatsBase]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "LogExpFunctions", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"] -git-tree-sha1 = "75ebe04c5bed70b91614d684259b661c9e6274a4" -uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -version = "0.34.0" - -[[deps.StatsFuns]] -deps = ["HypergeometricFunctions", "IrrationalConstants", "LogExpFunctions", "Reexport", "Rmath", "SpecialFunctions"] -git-tree-sha1 = "f625d686d5a88bcd2b15cd81f18f98186fdc0c9a" -uuid = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -version = "1.3.0" - - [deps.StatsFuns.extensions] - StatsFunsChainRulesCoreExt = "ChainRulesCore" - StatsFunsInverseFunctionsExt = "InverseFunctions" - - [deps.StatsFuns.weakdeps] - ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" - InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" - -[[deps.StatsModels]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsBase", "StatsFuns", "Tables"] -git-tree-sha1 = "8cc7a5385ecaa420f0b3426f9b0135d0df0638ed" -uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d" -version = "0.7.2" - -[[deps.StringManipulation]] -git-tree-sha1 = "46da2434b41f41ac3594ee9816ce5541c6096123" -uuid = "892a3eda-7b42-436c-8928-eab12a02cf0e" -version = "0.3.0" - -[[deps.StructArrays]] -deps = ["Adapt", "DataAPI", "GPUArraysCore", "StaticArraysCore", "Tables"] -git-tree-sha1 = "521a0e828e98bb69042fec1809c1b5a680eb7389" -uuid = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" -version = "0.6.15" - -[[deps.SuiteSparse]] -deps = ["Libdl", "LinearAlgebra", "Serialization", "SparseArrays"] -uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" - -[[deps.SuiteSparse_jll]] -deps = ["Artifacts", "Libdl", "Pkg", "libblastrampoline_jll"] -uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "5.10.1+6" - -[[deps.TOML]] -deps = ["Dates"] -uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" -version = "1.0.3" - -[[deps.TableTraits]] -deps = ["IteratorInterfaceExtensions"] -git-tree-sha1 = "c06b2f539df1c6efa794486abfb6ed2022561a39" -uuid = "3783bdb8-4a98-5b6b-af9a-565f29a5fe9c" -version = "1.0.1" - -[[deps.Tables]] -deps = ["DataAPI", "DataValueInterfaces", "IteratorInterfaceExtensions", "LinearAlgebra", "OrderedCollections", "TableTraits", "Test"] -git-tree-sha1 = "1544b926975372da01227b382066ab70e574a3ec" -uuid = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -version = "1.10.1" - -[[deps.Tar]] -deps = ["ArgTools", "SHA"] -uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" -version = "1.10.0" - -[[deps.TensorCore]] -deps = ["LinearAlgebra"] -git-tree-sha1 = "1feb45f88d133a655e001435632f019a9a1bcdb6" -uuid = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50" -version = "0.1.1" - -[[deps.Test]] -deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] -uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[[deps.TiffImages]] -deps = ["ColorTypes", "DataStructures", "DocStringExtensions", "FileIO", "FixedPointNumbers", "IndirectArrays", "Inflate", "Mmap", "OffsetArrays", "PkgVersion", "ProgressMeter", "UUIDs"] -git-tree-sha1 = "8621f5c499a8aa4aa970b1ae381aae0ef1576966" -uuid = "731e570b-9d59-4bfa-96dc-6df516fadf69" -version = "0.6.4" - -[[deps.TimerOutputs]] -deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" -uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.23" - -[[deps.TranscodingStreams]] -deps = ["Random", "Test"] -git-tree-sha1 = "9a6ae7ed916312b41236fcef7e0af564ef934769" -uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa" -version = "0.9.13" - -[[deps.TriplotBase]] -git-tree-sha1 = "4d4ed7f294cda19382ff7de4c137d24d16adc89b" -uuid = "981d1d27-644d-49a2-9326-4793e63143c3" -version = "0.1.0" - -[[deps.TupleTools]] -git-tree-sha1 = "3c712976c47707ff893cf6ba4354aa14db1d8938" -uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" -version = "1.3.0" - -[[deps.URIs]] -git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0" -uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" -version = "1.5.0" - -[[deps.UUIDs]] -deps = ["Random", "SHA"] -uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" - -[[deps.UnPack]] -git-tree-sha1 = "387c1f73762231e86e0c9c5443ce3b4a0a9a0c2b" -uuid = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" -version = "1.0.2" - -[[deps.Unicode]] -uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" - -[[deps.UnicodeFun]] -deps = ["REPL"] -git-tree-sha1 = "53915e50200959667e78a92a418594b428dffddf" -uuid = "1cfade01-22cf-5700-b092-accc4b62d6e1" -version = "0.4.1" - -[[deps.UnsafeAtomics]] -git-tree-sha1 = "6331ac3440856ea1988316b46045303bef658278" -uuid = "013be700-e6cd-48c3-b4a1-df204f14c38f" -version = "0.2.1" - -[[deps.UnsafeAtomicsLLVM]] -deps = ["LLVM", "UnsafeAtomics"] -git-tree-sha1 = "323e3d0acf5e78a56dfae7bd8928c989b4f3083e" -uuid = "d80eeb9a-aca5-4d75-85e5-170c8b632249" -version = "0.1.3" - -[[deps.WeakRefStrings]] -deps = ["DataAPI", "InlineStrings", "Parsers"] -git-tree-sha1 = "b1be2855ed9ed8eac54e5caff2afcdb442d52c23" -uuid = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" -version = "1.4.2" - -[[deps.WoodburyMatrices]] -deps = ["LinearAlgebra", "SparseArrays"] -git-tree-sha1 = "de67fa59e33ad156a590055375a30b23c40299d3" -uuid = "efce3f68-66dc-5838-9240-27a6d6f5f9b6" -version = "0.5.5" - -[[deps.WorkerUtilities]] -git-tree-sha1 = "cd1659ba0d57b71a464a29e64dbc67cfe83d54e7" -uuid = "76eceee3-57b5-4d4a-8e66-0e911cebbf60" -version = "1.6.1" - -[[deps.XML2_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libiconv_jll", "Pkg", "Zlib_jll"] -git-tree-sha1 = "93c41695bc1c08c46c5899f4fe06d6ead504bb73" -uuid = "02c8fc9c-b97f-50b9-bbe4-9be30ff0a78a" -version = "2.10.3+0" - -[[deps.XSLT_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Libgcrypt_jll", "Libgpg_error_jll", "Libiconv_jll", "Pkg", "XML2_jll", "Zlib_jll"] -git-tree-sha1 = "91844873c4085240b95e795f692c4cec4d805f8a" -uuid = "aed1982a-8fda-507f-9586-7b0439959a61" -version = "1.1.34+0" - -[[deps.Xorg_libX11_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Xorg_libxcb_jll", "Xorg_xtrans_jll"] -git-tree-sha1 = "afead5aba5aa507ad5a3bf01f58f82c8d1403495" -uuid = "4f6342f7-b3d2-589e-9d20-edeb45f2b2bc" -version = "1.8.6+0" - -[[deps.Xorg_libXau_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "6035850dcc70518ca32f012e46015b9beeda49d8" -uuid = "0c0b7dd1-d40b-584c-a123-a41640f87eec" -version = "1.0.11+0" - -[[deps.Xorg_libXdmcp_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "34d526d318358a859d7de23da945578e8e8727b7" -uuid = "a3789734-cfe1-5b06-b2d0-1dd0d9d62d05" -version = "1.1.4+0" - -[[deps.Xorg_libXext_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] -git-tree-sha1 = "b7c0aa8c376b31e4852b360222848637f481f8c3" -uuid = "1082639a-0dae-5f34-9b06-72781eeb8cb3" -version = "1.3.4+4" - -[[deps.Xorg_libXrender_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Xorg_libX11_jll"] -git-tree-sha1 = "19560f30fd49f4d4efbe7002a1037f8c43d43b96" -uuid = "ea2f1a96-1ddc-540d-b46f-429655e07cfa" -version = "0.9.10+4" - -[[deps.Xorg_libpthread_stubs_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "8fdda4c692503d44d04a0603d9ac0982054635f9" -uuid = "14d82f49-176c-5ed1-bb49-ad3f5cbd8c74" -version = "0.1.1+0" - -[[deps.Xorg_libxcb_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "XSLT_jll", "Xorg_libXau_jll", "Xorg_libXdmcp_jll", "Xorg_libpthread_stubs_jll"] -git-tree-sha1 = "b4bfde5d5b652e22b9c790ad00af08b6d042b97d" -uuid = "c7cfdc94-dc32-55de-ac96-5a1b8d977c5b" -version = "1.15.0+0" - -[[deps.Xorg_xtrans_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl"] -git-tree-sha1 = "e92a1a012a10506618f10b7047e478403a046c77" -uuid = "c5fb5394-a638-5e4d-96e5-b29de1b5cf10" -version = "1.5.0+0" - -[[deps.Zlib_jll]] -deps = ["Libdl"] -uuid = "83775a58-1f1d-513f-b197-d71354ab007a" -version = "1.2.13+0" - -[[deps.isoband_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "51b5eeb3f98367157a7a12a1fb0aa5328946c03c" -uuid = "9a68df92-36a6-505f-a73e-abb412b6bfb4" -version = "0.2.3+0" - -[[deps.libaom_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "3a2ea60308f0996d26f1e5354e10c24e9ef905d4" -uuid = "a4ae2306-e953-59d6-aa16-d00cac43593b" -version = "3.4.0+0" - -[[deps.libass_jll]] -deps = ["Artifacts", "Bzip2_jll", "FreeType2_jll", "FriBidi_jll", "HarfBuzz_jll", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "5982a94fcba20f02f42ace44b9894ee2b140fe47" -uuid = "0ac62f75-1d6f-5e53-bd7c-93b484bb37c0" -version = "0.15.1+0" - -[[deps.libblastrampoline_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" -version = "5.8.0+0" - -[[deps.libfdk_aac_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "daacc84a041563f965be61859a36e17c4e4fcd55" -uuid = "f638f0a6-7fb0-5443-88ba-1cc74229b280" -version = "2.0.2+0" - -[[deps.libpng_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg", "Zlib_jll"] -git-tree-sha1 = "94d180a6d2b5e55e447e2d27a29ed04fe79eb30c" -uuid = "b53b4c65-9356-5827-b1ea-8c7a1a84506f" -version = "1.6.38+0" - -[[deps.libsixel_jll]] -deps = ["Artifacts", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Pkg", "libpng_jll"] -git-tree-sha1 = "d4f63314c8aa1e48cd22aa0c17ed76cd1ae48c3c" -uuid = "075b6546-f08a-558a-be8f-8157d0f608a5" -version = "1.10.3+0" - -[[deps.libvorbis_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Ogg_jll", "Pkg"] -git-tree-sha1 = "b910cb81ef3fe6e78bf6acee440bda86fd6ae00c" -uuid = "f27f6e37-5d2b-51aa-960f-b287f2bc3b7a" -version = "1.3.7+1" - -[[deps.nghttp2_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" -version = "1.48.0+0" - -[[deps.p7zip_jll]] -deps = ["Artifacts", "Libdl"] -uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" -version = "17.4.0+0" - -[[deps.x264_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "4fea590b89e6ec504593146bf8b988b2c00922b2" -uuid = "1270edf5-f2f9-52d2-97e9-ab00b5d0237a" -version = "2021.5.5+0" - -[[deps.x265_jll]] -deps = ["Artifacts", "JLLWrappers", "Libdl", "Pkg"] -git-tree-sha1 = "ee567a171cce03570d77ad3a43e90218e38937a9" -uuid = "dfaa095f-4041-5dcd-9319-2fabd8486b76" -version = "3.5.0+0" diff --git a/docs/Project.toml b/docs/Project.toml index 1ffd37c6..eb3065ec 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -5,7 +5,6 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5" Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" @@ -15,23 +14,20 @@ MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91" NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SQLite = "0aa819cd-b072-5ff4-a722-6bc24af294d9" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" TMLE = "8afdd2fb-6e73-43df-8b62-b1650cd9c8cf" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -MLJXGBoostInterface = "0.3.8" CSV = "0.10.11" CairoMakie = "0.10.4" DataFrames = "1.5" Distributions = "0.25" Documenter = "1" -EvoTrees = "0.16" Literate = "2.13" MLJ = "0.19.1" MLJLinearModels = "0.9.1" +MLJXGBoostInterface = "0.3.8" NearestNeighborModels = "0.2" -SQLite = "1.4" StableRNGs = "1.0" diff --git a/src/TMLE.jl b/src/TMLE.jl index ac236d5e..b3ae6678 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -20,6 +20,7 @@ using Random import AbstractDifferentiation as AD using Graphs using MetaGraphsNext +using Configurations # ############################################################################# # EXPORTS @@ -37,6 +38,7 @@ export compose export TreatmentTransformer, with_encoder, encoder export BackdoorAdjustment, identify export last_fluctuation_epsilon +export estimands_from_yaml, estimands_to_yaml # ############################################################################# # INCLUDES @@ -56,8 +58,7 @@ include("counterfactual_mean_based/estimators.jl") include("counterfactual_mean_based/clever_covariate.jl") include("counterfactual_mean_based/gradient.jl") include("counterfactual_mean_based/adjustment.jl") - - +include("counterfactual_mean_based/configurations.jl") # ############################################################################# # PRECOMPILATION WORKLOAD diff --git a/src/counterfactual_mean_based/configurations.jl b/src/counterfactual_mean_based/configurations.jl new file mode 100644 index 00000000..a621e142 --- /dev/null +++ b/src/counterfactual_mean_based/configurations.jl @@ -0,0 +1,37 @@ +Configurations.from_dict(::Type{<:StatisticalCMCompositeEstimand}, ::Type{Symbol}, s) = Symbol(s) +Configurations.from_dict(::Type{<:StatisticalCMCompositeEstimand}, ::Type{<:Tuple}, s) = eval(Meta.parse(s)) +Configurations.from_dict(::Type{<:StatisticalCMCompositeEstimand}, ::Type{NamedTuple}, s) = eval(Meta.parse(s)) + +function estimand_to_dict(Ψ::StatisticalCMCompositeEstimand) + d = to_dict(Ψ, YAMLStyle) + d["type"] = typeof(Ψ) + return d +end + +""" + estimands_from_yaml(filepath) + +Read estimands from a YAML file. +""" +function estimands_from_yaml(filepath) + yaml_dict = YAML.load_file(filepath; dicttype=Dict{String, Any}) + estimands_dicts = yaml_dict["Estimands"] + nparams = size(estimands_dicts, 1) + estimands = Vector(undef, nparams) + for index in 1:nparams + d = estimands_dicts[index] + type = pop!(d, "type") + estimands[index] = from_dict(eval(Meta.parse(type)), d) + end + return estimands +end + +""" + estimands_to_yaml(filepath, estimands) + +Write estimands to a YAML file. +""" +function estimands_to_yaml(filepath, estimands) + d = Dict("Estimands" => [estimand_to_dict(Ψ) for Ψ in estimands]) + YAML.write_file(filepath, d) +end \ No newline at end of file diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index b9eaa5ef..06390caf 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -41,7 +41,7 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS statistical_estimand = Symbol(:Statistical, estimand) ex = quote # Causal Estimand - struct $(causal_estimand) <: Estimand + @option struct $(causal_estimand) <: Estimand outcome::Symbol treatment_values::NamedTuple @@ -52,7 +52,7 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS end end # Statistical Estimand - struct $(statistical_estimand) <: Estimand + @option struct $(statistical_estimand) <: Estimand outcome::Symbol treatment_values::NamedTuple treatment_confounders::NamedTuple diff --git a/test/counterfactual_mean_based/configurations.jl b/test/counterfactual_mean_based/configurations.jl new file mode 100644 index 00000000..10b8b8ef --- /dev/null +++ b/test/counterfactual_mean_based/configurations.jl @@ -0,0 +1,53 @@ +module TestConfigurations + +using Test +using TMLE +using Serialization + +@testset "Test configurations" begin + estimand_file = "estimands_sample.yaml" + estimands = [ + IATE( + outcome=:Y1, + treatment_values= ( + T1 = (case = 1, control = 0), + T2 = (case = 1, control = 0) + ), + treatment_confounders=( + T1 = [:W11, :W12], + T2 = [:W21, :W22], + ), + outcome_extra_covariates=[:C1] + ), + ATE(; + outcome=:Y3, + treatment_values = ( + T1 = (case = 1, control = 0), + T3 = (case = "AC", control = "CC") + ), + treatment_confounders=( + T1 = [:W], + T3 = [:W] + ), + ), + CM(; + outcome=:Y3, + treatment_values=(T3 = "AC", T1 = "CC"), + treatment_confounders=(T3 = [:W], T1 = [:W]) + ) + ] + # Test YAML Serialization + estimands_to_yaml(estimand_file, estimands) + loaded_estimands = estimands_from_yaml(estimand_file) + @test loaded_estimands == estimands + rm(estimand_file) + # Test basic Serialization + serialize(estimand_file, estimands) + loaded_estimands = deserialize(estimand_file) + @test loaded_estimands == estimands + rm(estimand_file) +end + +end + +true \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 63439c71..67828747 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,4 +19,5 @@ using Test @test include("counterfactual_mean_based/double_robustness_iate.jl") @test include("counterfactual_mean_based/3points_interactions.jl") @test include("counterfactual_mean_based/adjustment.jl") + @test include("counterfactual_mean_based/configurations.jl") end \ No newline at end of file From b3170bcf805f91893e944da688c5d21ccd3449db Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 6 Nov 2023 16:16:29 +0000 Subject: [PATCH 102/151] add configuration --- Project.toml | 4 - src/TMLE.jl | 8 +- src/configuration.jl | 36 ++++++ src/counterfactual_mean_based/adjustment.jl | 7 +- .../configurations.jl | 37 ------ src/counterfactual_mean_based/estimands.jl | 63 ++++++++- src/scm.jl | 30 +++-- test/Project.toml | 2 + test/configuration.jl | 120 ++++++++++++++++++ test/counterfactual_mean_based/adjustment.jl | 11 +- .../configurations.jl | 53 -------- test/counterfactual_mean_based/estimands.jl | 74 +++++++++++ test/runtests.jl | 2 +- test/scm.jl | 57 +++++++-- 14 files changed, 379 insertions(+), 125 deletions(-) create mode 100644 src/configuration.jl delete mode 100644 src/counterfactual_mean_based/configurations.jl create mode 100644 test/configuration.jl delete mode 100644 test/counterfactual_mean_based/configurations.jl diff --git a/Project.toml b/Project.toml index 5be2577f..480e1a31 100644 --- a/Project.toml +++ b/Project.toml @@ -6,7 +6,6 @@ version = "0.12.0" [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" -Configurations = "5218b696-f38b-4ac9-8b61-a12ec717816d" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -23,7 +22,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" -YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] @@ -43,7 +41,5 @@ PrecompileTools = "1.1.1" PrettyTables = "2.2" TableOperations = "1.2" Tables = "1.6" -YAML = "0.4" Zygote = "0.6" -Configurations = "0.17.6" julia = "1.6, 1.7, 1" diff --git a/src/TMLE.jl b/src/TMLE.jl index b3ae6678..bc98fd28 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -13,20 +13,18 @@ using Statistics using Distributions using Zygote using LogExpFunctions -using YAML using PrecompileTools using PrettyTables using Random import AbstractDifferentiation as AD using Graphs using MetaGraphsNext -using Configurations # ############################################################################# # EXPORTS # ############################################################################# -export SCM, StaticSCM, add_equations!, add_equation!, parents +export SCM, StaticSCM, add_equations!, add_equation!, parents, vertices export CounterfactualMean, CM export AverageTreatmentEffect, ATE export InteractionAverageTreatmentEffect, IATE @@ -39,6 +37,7 @@ export TreatmentTransformer, with_encoder, encoder export BackdoorAdjustment, identify export last_fluctuation_epsilon export estimands_from_yaml, estimands_to_yaml +export to_dict, from_dict!, Configuration # ############################################################################# # INCLUDES @@ -58,7 +57,8 @@ include("counterfactual_mean_based/estimators.jl") include("counterfactual_mean_based/clever_covariate.jl") include("counterfactual_mean_based/gradient.jl") include("counterfactual_mean_based/adjustment.jl") -include("counterfactual_mean_based/configurations.jl") + +include("configuration.jl") # ############################################################################# # PRECOMPILATION WORKLOAD diff --git a/src/configuration.jl b/src/configuration.jl new file mode 100644 index 00000000..d0a67c13 --- /dev/null +++ b/src/configuration.jl @@ -0,0 +1,36 @@ +struct Configuration + estimands::AbstractVector{<:Estimand} + scm::Union{Nothing, SCM} + adjustment::Union{Nothing, <:AdjustmentMethod} +end + +Configuration(;estimands, scm=nothing, adjustment=nothing) = Configuration(estimands, scm, adjustment) + +function to_dict(configuration::Configuration) + config_dict = Dict( + :type => "Configuration", + :estimands => [to_dict(Ψ) for Ψ in configuration.estimands], + ) + if configuration.scm !== nothing + config_dict[:scm] = to_dict(configuration.scm) + end + if configuration.adjustment !== nothing + config_dict[:adjustment] = to_dict(configuration.adjustment) + end + return config_dict +end + +from_dict!(x) = x + +from_dict!(v::AbstractVector) = [from_dict!(x) for x in v] + +""" + from_dict!(d::Dict) + +Converts a dictionary to a TMLE struct. +""" +function from_dict!(d::Dict{T, Any}) where T + haskey(d, T(:type)) || return d + constructor = eval(Meta.parse(pop!(d, :type))) + return constructor(;[key => from_dict!(val) for (key, val) in d]...) +end \ No newline at end of file diff --git a/src/counterfactual_mean_based/adjustment.jl b/src/counterfactual_mean_based/adjustment.jl index 88b0bec8..d8a42c43 100644 --- a/src/counterfactual_mean_based/adjustment.jl +++ b/src/counterfactual_mean_based/adjustment.jl @@ -38,4 +38,9 @@ function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) wher treatment_confounders = treatment_confounders, outcome_extra_covariates = method.outcome_extra_covariates ) -end \ No newline at end of file +end + +to_dict(adjustment::BackdoorAdjustment) = Dict( + :type => "BackdoorAdjustment", + :outcome_extra_covariates => collect(adjustment.outcome_extra_covariates) +) \ No newline at end of file diff --git a/src/counterfactual_mean_based/configurations.jl b/src/counterfactual_mean_based/configurations.jl deleted file mode 100644 index a621e142..00000000 --- a/src/counterfactual_mean_based/configurations.jl +++ /dev/null @@ -1,37 +0,0 @@ -Configurations.from_dict(::Type{<:StatisticalCMCompositeEstimand}, ::Type{Symbol}, s) = Symbol(s) -Configurations.from_dict(::Type{<:StatisticalCMCompositeEstimand}, ::Type{<:Tuple}, s) = eval(Meta.parse(s)) -Configurations.from_dict(::Type{<:StatisticalCMCompositeEstimand}, ::Type{NamedTuple}, s) = eval(Meta.parse(s)) - -function estimand_to_dict(Ψ::StatisticalCMCompositeEstimand) - d = to_dict(Ψ, YAMLStyle) - d["type"] = typeof(Ψ) - return d -end - -""" - estimands_from_yaml(filepath) - -Read estimands from a YAML file. -""" -function estimands_from_yaml(filepath) - yaml_dict = YAML.load_file(filepath; dicttype=Dict{String, Any}) - estimands_dicts = yaml_dict["Estimands"] - nparams = size(estimands_dicts, 1) - estimands = Vector(undef, nparams) - for index in 1:nparams - d = estimands_dicts[index] - type = pop!(d, "type") - estimands[index] = from_dict(eval(Meta.parse(type)), d) - end - return estimands -end - -""" - estimands_to_yaml(filepath, estimands) - -Write estimands to a YAML file. -""" -function estimands_to_yaml(filepath, estimands) - d = Dict("Estimands" => [estimand_to_dict(Ψ) for Ψ in estimands]) - YAML.write_file(filepath, d) -end \ No newline at end of file diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 06390caf..2cb47abe 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -41,18 +41,18 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS statistical_estimand = Symbol(:Statistical, estimand) ex = quote # Causal Estimand - @option struct $(causal_estimand) <: Estimand + struct $(causal_estimand) <: Estimand outcome::Symbol treatment_values::NamedTuple function $(causal_estimand)(outcome, treatment_values) outcome = Symbol(outcome) - treatment_variables = Tuple(keys(treatment_values)) + treatment_values = get_treatment_specs(treatment_values) return new(outcome, treatment_values) end end # Statistical Estimand - @option struct $(statistical_estimand) <: Estimand + struct $(statistical_estimand) <: Estimand outcome::Symbol treatment_values::NamedTuple treatment_confounders::NamedTuple @@ -60,6 +60,7 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS function $(statistical_estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) outcome = Symbol(outcome) + treatment_values = get_treatment_specs(treatment_values) treatment_variables = Tuple(keys(treatment_values)) treatment_confounders = NamedTuple{treatment_variables}([unique_sorted_tuple(treatment_confounders[T]) for T ∈ treatment_variables]) outcome_extra_covariates = unique_sorted_tuple(outcome_extra_covariates) @@ -68,6 +69,12 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS end # Constructors + $(causal_estimand)(;outcome, treatment_values) = $(causal_estimand)(outcome, treatment_values) + + $(statistical_estimand)(;outcome, treatment_values, treatment_confounders, outcome_extra_covariates) = + $(statistical_estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) + + # Short Name Constructors $(estimand)(outcome, treatment_values) = $(causal_estimand)(outcome, treatment_values) $(estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates) = @@ -126,4 +133,54 @@ function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: StatisticalCMCo "\nTreatment: ", Ψ.treatment_values ) println(io, param_string) +end + +function treatment_specs_to_dict(treatment_values::NamedTuple{T, <:Tuple{Vararg{<:NamedTuple}}}) where T + Dict(key => Dict(pairs(vals)) for (key, vals) in pairs(treatment_values)) +end + +treatment_specs_to_dict(treatment_values::NamedTuple) = Dict(pairs(treatment_values)) + +treatment_values(d::AbstractDict) = (;d...) +treatment_values(d) = d + +function treatment_specs_from_dict(dict::AbstractDict) + treatment_variables = keys(dict) + return NamedTuple{Tuple(treatment_variables)}([treatment_values(val) for val in values(dict)]) +end + +get_treatment_specs(treatment_values::NamedTuple) = treatment_values +get_treatment_specs(treatment_values::AbstractDict) = treatment_specs_from_dict(treatment_values) + +constructorname(T; prefix="TMLE.Causal") = replace(string(T), prefix => "") + +treatment_confounders_to_dict(treatment_confounders::NamedTuple) = + Dict(key => collect(vals) for (key, vals) in pairs(treatment_confounders)) + +""" + to_dict(Ψ::T) where T <: CausalCMCompositeEstimands + +Converts Ψ to a dictionary that can be serialized. +""" +function to_dict(Ψ::T) where T <: CausalCMCompositeEstimands + return Dict( + :type => constructorname(T; prefix="TMLE.Causal"), + :outcome => Ψ.outcome, + :treatment_values => treatment_specs_to_dict(Ψ.treatment_values) + ) +end + +""" + to_dict(Ψ::T) where T <: StatisticalCMCompositeEstimand + +Converts Ψ to a dictionary that can be serialized. +""" +function to_dict(Ψ::T) where T <: StatisticalCMCompositeEstimand + return Dict( + :type => constructorname(T; prefix="TMLE.Statistical"), + :outcome => Ψ.outcome, + :treatment_values => treatment_specs_to_dict(Ψ.treatment_values), + :treatment_confounders => treatment_confounders_to_dict(Ψ.treatment_confounders), + :outcome_extra_covariates => collect(Ψ.outcome_extra_covariates) + ) end \ No newline at end of file diff --git a/src/scm.jl b/src/scm.jl index 9820529d..dcad8a6e 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -9,7 +9,7 @@ struct SCM graph::MetaGraph end -function SCM(equations...) +function SCM(equations) graph = MetaGraph( SimpleDiGraph(); label_type=Symbol, @@ -21,6 +21,8 @@ function SCM(equations...) return scm end +SCM(;equations=()) = SCM(equations) + function string_repr(scm::SCM) string_rep = "SCM\n---" for (vertex_id, vertex_label) in scm.graph.vertex_labels @@ -38,6 +40,8 @@ end Base.show(io::IO, ::MIME"text/plain", scm::SCM) = println(io, string_repr(scm)) +split_outcome_parent_pair(outcome_parents_pair::Pair) = outcome_parents_pair +split_outcome_parent_pair(outcome_parents_pair::Dict{T, Any}) where T = outcome_parents_pair[T(:outcome)], outcome_parents_pair[T(:parents)] function add_equations!(scm::SCM, equations...) for outcome_parents_pair in equations @@ -46,7 +50,7 @@ function add_equations!(scm::SCM, equations...) end function add_equation!(scm::SCM, outcome_parents_pair) - outcome, parents = outcome_parents_pair + outcome, parents = split_outcome_parent_pair(outcome_parents_pair) outcome_symbol = Symbol(outcome) add_vertex!(scm.graph, outcome_symbol) for parent in parents @@ -56,11 +60,13 @@ function add_equation!(scm::SCM, outcome_parents_pair) end end -function parents(scm, label) +function parents(scm::SCM, label) code, _ = scm.graph.vertex_properties[label] - return [scm.graph.vertex_labels[parent_code] for parent_code in scm.graph.graph.badjlist[code]] + return Set((scm.graph.vertex_labels[parent_code] for parent_code in scm.graph.graph.badjlist[code])) end +vertices(scm::SCM) = collect(keys(scm.graph.vertex_properties)) + """ A plate Structural Causal Model where: @@ -71,11 +77,17 @@ A plate Structural Causal Model where: StaticSCM([:Y], [:T₁, :T₂], [:W₁, :W₂, :W₃]; outcome_extra_covariates=[:C]) """ -function StaticSCM(outcomes, treatments, confounders; outcome_extra_covariates=()) - outcome_equations = (outcome => unique(vcat(treatments, confounders, outcome_extra_covariates)) for outcome in outcomes) +function StaticSCM(outcomes, treatments, confounders) + outcome_equations = (outcome => unique(vcat(treatments, confounders)) for outcome in outcomes) treatment_equations = (treatment => unique(confounders) for treatment in treatments) - return SCM(outcome_equations..., treatment_equations...) + return SCM(equations=(outcome_equations..., treatment_equations...)) end -StaticSCM(;outcomes, treatments, confounders, outcome_extra_covariates=()) = - StaticSCM(outcomes, treatments, confounders; outcome_extra_covariates=outcome_extra_covariates) +StaticSCM(;outcomes, treatments, confounders) = + StaticSCM(outcomes, treatments, confounders) + + +to_dict(scm::SCM) = Dict( + :type => "SCM", + :equations => [Dict(:outcome => label, :parents => collect(parents(scm, label))) for label ∈ vertices(scm)] +) \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index 065a040d..eb966128 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -18,6 +18,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" [compat] CSV = "0.10" @@ -25,3 +26,4 @@ DataFrames = "1.5" MLJLinearModels = "0.9" StableRNGs = "1.0" StatsBase = "0.33" +YAML = "0.4.9" diff --git a/test/configuration.jl b/test/configuration.jl new file mode 100644 index 00000000..3e9b3c8b --- /dev/null +++ b/test/configuration.jl @@ -0,0 +1,120 @@ +module TestConfiguration + +using Test +using TMLE +using YAML + +@testset "Test Configurations" begin + outfilename = "test_configuration.yml" + estimands = [ + IATE( + outcome=:Y1, + treatment_values= ( + T1 = (case = 1, control = 0), + T2 = (case = 1, control = 0) + ), + treatment_confounders=( + T1 = [:W11, :W12], + T2 = [:W21, :W22], + ), + outcome_extra_covariates=[:C1] + ), + ATE(; + outcome=:Y3, + treatment_values = ( + T1 = (case = 1, control = 0), + T3 = (case = "AC", control = "CC") + ), + treatment_confounders=( + T1 = [:W], + T3 = [:W] + ), + ), + CM(; + outcome=:Y3, + treatment_values=(T3 = "AC", T1 = "CC"), + treatment_confounders=(T3 = [:W], T1 = [:W]) + ) + ] + configuration = Configuration( + estimands=estimands + ) + config_dict = to_dict(configuration) + YAML.write_file(outfilename, config_dict) + dict_from_yaml = YAML.load_file(outfilename, dicttype=Dict{Symbol,Any}) + reconstructed_configuration = from_dict!(dict_from_yaml) + @test reconstructed_configuration.scm === nothing + @test reconstructed_configuration.adjustment === nothing + for index in 1:length(estimands) + estimand = configuration.estimands[index] + reconstructed_estimand = reconstructed_configuration.estimands[index] + @test estimand.outcome == reconstructed_estimand.outcome + for treatment in keys(estimand.treatment_values) + @test estimand.treatment_values[treatment] == reconstructed_estimand.treatment_values[treatment] + @test estimand.treatment_confounders[treatment] == reconstructed_estimand.treatment_confounders[treatment] + @test estimand.outcome_extra_covariates == reconstructed_estimand.outcome_extra_covariates + end + end + rm(outfilename) + + # With a StaticSCM, some Causal estimands and an Adjustment Method + estimands = [ + IATE( + outcome=:Y1, + treatment_values= ( + T1 = (case = 1, control = 0), + T2 = (case = 1, control = 0) + ) + ), + ATE(; + outcome=:Y3, + treatment_values = ( + T1 = (case = 1, control = 0), + T3 = (case = "AC", control = "CC") + ), + treatment_confounders=( + T1 = [:W], + T3 = [:W] + ), + ), + ] + + configuration = Configuration( + estimands=estimands, + adjustment=BackdoorAdjustment([:C1, :C2]) + ) + config_dict = to_dict(configuration) + config_dict[:scm] = Dict( + :type => "StaticSCM", + :confounders => [:W], + :outcomes => [:Y1], + :treatments => [:T1, :T2] + ) + YAML.write_file(outfilename, config_dict) + config_from_yaml = from_dict!(YAML.load_file(outfilename, dicttype=Dict{Symbol,Any})) + + scm = config_from_yaml.scm + adjustment = config_from_yaml.adjustment + # Estimand 1 + estimand₁ = config_from_yaml.estimands[1] + statistical_estimand₁ = identify(adjustment, estimand₁, scm) + @test statistical_estimand₁.outcome == :Y1 + @test statistical_estimand₁.treatment_values.T1 == (case=1, control=0) + @test statistical_estimand₁.treatment_values.T2 == (case=1, control=0) + @test statistical_estimand₁.treatment_confounders.T2 == (:W,) + @test statistical_estimand₁.treatment_confounders.T1 == (:W,) + @test statistical_estimand₁.outcome_extra_covariates == (:C1, :C2) + # Estimand 2 + estimand₂ = config_from_yaml.estimands[2] + @test estimand₂.outcome == :Y3 + @test estimand₂.treatment_values.T1 == (case = 1, control = 0) + @test estimand₂.treatment_values.T3 == (case = "AC", control = "CC") + @test estimand₂.treatment_confounders.T1 == (:W,) + @test estimand₂.treatment_confounders.T3 == (:W,) + + rm(outfilename) +end + +end + +true \ No newline at end of file diff --git a/test/counterfactual_mean_based/adjustment.jl b/test/counterfactual_mean_based/adjustment.jl index 688cbd29..29638d98 100644 --- a/test/counterfactual_mean_based/adjustment.jl +++ b/test/counterfactual_mean_based/adjustment.jl @@ -8,7 +8,6 @@ using TMLE outcomes = [:Y₁, :Y₂], treatments = [:T₁, :T₂], confounders = [:W₁, :W₂], - outcome_extra_covariates = [:C] ) causal_estimands = [ CM(outcome=:Y₁, treatment_values=(T₁=1,)), @@ -30,6 +29,16 @@ using TMLE @test statistical_estimands[4].treatment_confounders == (T₁=(:W₁, :W₂), T₂=(:W₁, :W₂)) end +@testset "Test to_dict" begin + adjustment = BackdoorAdjustment(outcome_extra_covariates=[:C]) + adjustment_dict = to_dict(adjustment) + @test adjustment_dict == Dict( + :outcome_extra_covariates => [:C], + :type => "BackdoorAdjustment" + ) + @test from_dict!(adjustment_dict) == adjustment +end + end true \ No newline at end of file diff --git a/test/counterfactual_mean_based/configurations.jl b/test/counterfactual_mean_based/configurations.jl deleted file mode 100644 index 10b8b8ef..00000000 --- a/test/counterfactual_mean_based/configurations.jl +++ /dev/null @@ -1,53 +0,0 @@ -module TestConfigurations - -using Test -using TMLE -using Serialization - -@testset "Test configurations" begin - estimand_file = "estimands_sample.yaml" - estimands = [ - IATE( - outcome=:Y1, - treatment_values= ( - T1 = (case = 1, control = 0), - T2 = (case = 1, control = 0) - ), - treatment_confounders=( - T1 = [:W11, :W12], - T2 = [:W21, :W22], - ), - outcome_extra_covariates=[:C1] - ), - ATE(; - outcome=:Y3, - treatment_values = ( - T1 = (case = 1, control = 0), - T3 = (case = "AC", control = "CC") - ), - treatment_confounders=( - T1 = [:W], - T3 = [:W] - ), - ), - CM(; - outcome=:Y3, - treatment_values=(T3 = "AC", T1 = "CC"), - treatment_confounders=(T3 = [:W], T1 = [:W]) - ) - ] - # Test YAML Serialization - estimands_to_yaml(estimand_file, estimands) - loaded_estimands = estimands_from_yaml(estimand_file) - @test loaded_estimands == estimands - rm(estimand_file) - # Test basic Serialization - serialize(estimand_file, estimands) - loaded_estimands = deserialize(estimand_file) - @test loaded_estimands == estimands - rm(estimand_file) -end - -end - -true \ No newline at end of file diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index 8d0ec93d..4392a7c1 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -86,6 +86,80 @@ end end end +@testset "Test dictionary conversion" begin + # Causal CM + Ψ = CM( + outcome=:y, + treatment_values = (T₁=1, T₂="AC") + ) + d = to_dict(Ψ) + @test d == Dict( + :type => "CM", + :treatment_values => Dict(:T₁=>1, :T₂=>"AC"), + :outcome => :y + ) + Ψreconstructed = from_dict!(d) + @test Ψreconstructed == Ψ + # Causal ATE + Ψ = ATE( + outcome=:y, + treatment_values=( + T₁=(case="A", control="B"), + T₂=(case=1, control=0), + T₃=(case="C", control="D") + ), + ) + d = to_dict(Ψ) + @test d == Dict( + :type => "ATE", + :treatment_values => Dict( + :T₁ => Dict(:case => "A", :control => "B"), + :T₂ => Dict(:case => 1, :control => 0), + :T₃ => Dict(:control => "D", :case => "C") + ), + :outcome => :y + ) + Ψreconstructed = from_dict!(d) + @test Ψreconstructed.outcome == Ψ.outcome + @test Ψreconstructed.treatment_values.T₁ == Ψ.treatment_values.T₁ + @test Ψreconstructed.treatment_values.T₂ == Ψ.treatment_values.T₂ + @test Ψreconstructed.treatment_values.T₃ == Ψ.treatment_values.T₃ + + # Statistical CM + Ψ = CM( + outcome=:y, + treatment_values = (T₁=1, T₂="AC"), + treatment_confounders = (T₁=[:W₁₁, :W₁₂], T₂=[:W₁₂, :W₂₂]) + ) + d = to_dict(Ψ) + @test d == Dict( + :outcome_extra_covariates => [], + :type => "CM", + :treatment_values => Dict(:T₁=>1, :T₂=>"AC"), + :outcome => :y, + :treatment_confounders => Dict(:T₁=>[:W₁₁, :W₁₂], :T₂=>[:W₁₂, :W₂₂]) + ) + Ψreconstructed = from_dict!(d) + @test Ψreconstructed == Ψ + + # Statistical IATE + Ψ = IATE( + outcome=:y, + treatment_values = (T₁=1, T₂="AC"), + treatment_confounders = (T₁=[:W₁₁, :W₁₂], T₂=[:W₁₂, :W₂₂]), + outcome_extra_covariates=[:C] + ) + d = to_dict(Ψ) + @test d == Dict( + :outcome_extra_covariates => [:C], + :type => "IATE", + :treatment_values => Dict(:T₁=>1, :T₂=>"AC"), + :outcome => :y, + :treatment_confounders => Dict(:T₁=>[:W₁₁, :W₁₂], :T₂=>[:W₁₂, :W₂₂]) + ) + Ψreconstructed = from_dict!(d) + @test Ψreconstructed == Ψ +end end true \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 67828747..3d11f7ef 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using Test @test include("composition.jl") @test include("treatment_transformer.jl") @test include("scm.jl") + @test include("configuration.jl") @test include("counterfactual_mean_based/estimands.jl") @test include("counterfactual_mean_based/clever_covariate.jl") @@ -19,5 +20,4 @@ using Test @test include("counterfactual_mean_based/double_robustness_iate.jl") @test include("counterfactual_mean_based/3points_interactions.jl") @test include("counterfactual_mean_based/adjustment.jl") - @test include("counterfactual_mean_based/configurations.jl") end \ No newline at end of file diff --git a/test/scm.jl b/test/scm.jl index 34c08bfd..65dfc333 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -12,21 +12,44 @@ using TMLE :T₁ => [:W₁] ) add_equation!(scm, :T₂ => [:W₂]) - @test parents(scm, :T₁) == [:W₁] - @test parents(scm, :T₂) == [:W₂] - @test parents(scm, :Y) == [:T₁, :T₂, :W₁, :W₂, :C] - @test parents(scm, :C) == [] + @test parents(scm, :T₁) == Set([:W₁]) + @test parents(scm, :T₂) == Set([:W₂]) + @test parents(scm, :Y) == Set([:T₁, :T₂, :W₁, :W₂, :C]) + @test parents(scm, :C) == Set([]) + expected_vertices = Set([:Y, :T₁, :T₂, :W₁, :W₂, :C]) + @test Set(vertices(scm)) == expected_vertices + + scm_dict = to_dict(scm) + @test scm_dict == Dict( + :type => "SCM", + :equations => [ + Dict(:parents => Any[], :outcome => :C), + Dict(:parents => Any[], :outcome => :W₁), + Dict(:parents => [:C, :W₁, :T₁, :W₂, :T₂], :outcome => :Y), + Dict(:parents => [:W₁], :outcome => :T₁), + Dict(:parents => Any[], :outcome => :W₂), + Dict(:parents => [:W₂], :outcome => :T₂) + ] + ) + reconstructed_scm = from_dict!(scm_dict) + for vertexlabel ∈ expected_vertices + @test parents(reconstructed_scm, vertexlabel) == parents(scm, vertexlabel) + end + # Another SCM - scm = SCM( + scm = SCM([ :Y₁ => [:T₁, :W₁], :T₁ => [:W₁], :Y₂ => [:T₁, :T₂, :W₁, :W₂, :C], :T₂ => [:W₂] + ] ) - @test parents(scm, :Y₁) == [:T₁, :W₁] - @test parents(scm, :T₁) == [:W₁] - @test parents(scm, :Y₂) == [:T₁, :W₁, :T₂, :W₂, :C] - @test parents(scm, :T₂) == [:W₂] + @test parents(scm, :Y₁) == Set([:T₁, :W₁]) + @test parents(scm, :T₁) == Set([:W₁]) + @test parents(scm, :Y₂) == Set([:T₁, :W₁, :T₂, :W₂, :C]) + @test parents(scm, :T₂) == Set([:W₂]) + @test Set(vertices(scm)) == Set([:Y₁, :Y₂, :T₁, :T₂, :W₁, :W₂, :C]) + end @testset "Test StaticSCM" begin @@ -34,10 +57,20 @@ end outcomes = [:Y₁, :Y₂], treatments = [:T₁, :T₂, :T₃], confounders = [:W₁, :W₂], - outcome_extra_covariates = [:C] ) - @test parents(scm, :Y₁) == parents(scm, :Y₂) == [:T₁, :T₂, :T₃, :W₁, :W₂, :C] - @test parents(scm, :T₁) == parents(scm, :T₂) == parents(scm, :T₃) == [:W₁, :W₂] + @test parents(scm, :Y₁) == parents(scm, :Y₂) == Set([:T₁, :T₂, :T₃, :W₁, :W₂]) + @test parents(scm, :T₁) == parents(scm, :T₂) == parents(scm, :T₃) == Set([:W₁, :W₂]) + + reconstructed_scm = from_dict!(Dict( + :type => "StaticSCM", + :outcomes => [:Y₁, :Y₂], + :treatments => [:T₁, :T₂, :T₃], + :confounders => [:W₁, :W₂] + )) + for vertexlabel ∈ [:Y₁, :Y₂, :T₁, :T₂, :T₃, :W₁, :W₂] + @test parents(reconstructed_scm, vertexlabel) == parents(scm, vertexlabel) + end + end end From 78b94ec40537b6a23d7f617398cff0ae91529e08 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 6 Nov 2023 16:30:21 +0000 Subject: [PATCH 103/151] update MLJBase dep and fix warnings --- Project.toml | 2 +- src/counterfactual_mean_based/estimands.jl | 2 +- src/scm.jl | 24 +++++++++++----------- test/Project.toml | 2 ++ test/helper_fns.jl | 1 + 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 480e1a31..217649a1 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ GLM = "1.8.2" Graphs = "1.8" HypothesisTests = "0.10, 0.11" LogExpFunctions = "0.3" -MLJBase = "0.19, 0.20, 0.21" +MLJBase = "1.0.1" MLJGLMInterface = "0.3.4" MLJModels = "0.15, 0.16" MetaGraphsNext = "0.6" diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 2cb47abe..c2acae2e 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -135,7 +135,7 @@ function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: StatisticalCMCo println(io, param_string) end -function treatment_specs_to_dict(treatment_values::NamedTuple{T, <:Tuple{Vararg{<:NamedTuple}}}) where T +function treatment_specs_to_dict(treatment_values::NamedTuple{T, <:Tuple{Vararg{NamedTuple}}}) where T Dict(key => Dict(pairs(vals)) for (key, vals) in pairs(treatment_values)) end diff --git a/src/scm.jl b/src/scm.jl index dcad8a6e..9bdb393f 100644 --- a/src/scm.jl +++ b/src/scm.jl @@ -7,18 +7,18 @@ A SCM is simply a wrapper around a MetaGraph over a Directed Acyclic Graph. """ struct SCM graph::MetaGraph -end - -function SCM(equations) - graph = MetaGraph( - SimpleDiGraph(); - label_type=Symbol, - vertex_data_type=Nothing, - edge_data_type=Nothing - ) - scm = SCM(graph) - add_equations!(scm, equations...) - return scm + + function SCM(equations) + graph = MetaGraph( + SimpleDiGraph(); + label_type=Symbol, + vertex_data_type=Nothing, + edge_data_type=Nothing + ) + scm = new(graph) + add_equations!(scm, equations...) + return scm + end end SCM(;equations=()) = SCM(equations) diff --git a/test/Project.toml b/test/Project.toml index eb966128..f7b8c11d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" @@ -27,3 +28,4 @@ MLJLinearModels = "0.9" StableRNGs = "1.0" StatsBase = "0.33" YAML = "0.4.9" +StatisticalMeasures = "0.1.3" \ No newline at end of file diff --git a/test/helper_fns.jl b/test/helper_fns.jl index aaf3ab92..2ab0ae0a 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -2,6 +2,7 @@ using TMLE using Test using MLJBase using CategoricalArrays +using StatisticalMeasures """ The risk for continuous outcomes is the MSE From 71820a652d530c9d1058c5d2586c6b1f6fe1edd0 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 6 Nov 2023 16:36:38 +0000 Subject: [PATCH 104/151] up AD dep --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 217649a1..603007bb 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AbstractDifferentiation = "0.4, 0.5" +AbstractDifferentiation = "0.6.0" CategoricalArrays = "0.10" Distributions = "0.25" GLM = "1.8.2" From ce918ddb4a56c8868b92564e556429de511ade06 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 6 Nov 2023 16:44:43 +0000 Subject: [PATCH 105/151] up MLJ docs deps --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index eb3065ec..be5eafdb 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -26,7 +26,7 @@ DataFrames = "1.5" Distributions = "0.25" Documenter = "1" Literate = "2.13" -MLJ = "0.19.1" +MLJ = "0.20.1" MLJLinearModels = "0.9.1" MLJXGBoostInterface = "0.3.8" NearestNeighborModels = "0.2" From fac1bdad40d02ff721a52d233e2a3daf441c8fe7 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 6 Nov 2023 17:07:05 +0000 Subject: [PATCH 106/151] add some docs --- docs/src/user_guide/estimation.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index 01d5d00a..9f78b1e5 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -113,6 +113,23 @@ nothing # hide Again, required nuisance functions are fitted and stored in the cache. +## CV-Estimation + +Both TMLE and OSE can be used with sample-splitting, which, for an additional computational cost, further reduces the assumptions we need to make regarding our data generating process ([see here](https://arxiv.org/abs/2203.06469)). Note that this sample-splitting procedure should not be confused with the sample-splitting happening in Super Learning. Using both CV-TMLE and Super-Learning will result in two nested sample-splitting loops. + +To leverage sample-splitting, simply specify a `resampling` strategy when building an estimator: + +```@example estimation +cvtmle = TMLEE(models, resampling=CV()) +cvresult₁, _ = cvtmle(Ψ₁, dataset); +``` + +Similarly, one could build CV-OSE: + +```julia +cvose = OSE(models, resampling=CV(nfolds=3)) +``` + ## Reusing the SCM Let's now see how the `cache` can be reused with a new estimand, say the Total Average Treatment Effect of both `T₁` and `T₂`. From 5fa052a8f564c502dd3432fb2f25a9e6249c9089 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 6 Nov 2023 21:57:22 +0000 Subject: [PATCH 107/151] update docs --- docs/src/user_guide/estimands.md | 6 +++--- docs/src/user_guide/estimation.md | 5 +++-- docs/src/user_guide/scm.md | 13 ++++++------- docs/src/walk_through.md | 3 ++- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/docs/src/user_guide/estimands.md b/docs/src/user_guide/estimands.md index 219b6f7c..2b336679 100644 --- a/docs/src/user_guide/estimands.md +++ b/docs/src/user_guide/estimands.md @@ -17,9 +17,9 @@ In what follows, ``P`` is a probability distribution generating an outcome ``Y`` ```@example estimands using TMLE scm = StaticSCM( - [:Y], - [:T₁, :T₂], - [:W]; + outcomes=[:Y], + treatments=[:T₁, :T₂], + confounders=[:W] ) ``` diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index 9f78b1e5..5c5e2a92 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -43,10 +43,11 @@ function make_dataset(;n=1000) ) end dataset = make_dataset(n=10000) -scm = SCM( +scm = SCM([ :Y => [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], :T₁ => [:W₁₁, :W₁₂], - :T₂ => [:W₂₁, :W₂₂], + :T₂ => [:W₂₁, :W₂₂] +] ) ``` diff --git a/docs/src/user_guide/scm.md b/docs/src/user_guide/scm.md index 641e580e..371d3c2f 100644 --- a/docs/src/user_guide/scm.md +++ b/docs/src/user_guide/scm.md @@ -32,11 +32,11 @@ add_equations!(scm, :T₁ => [:W₁₁, :W₁₂, :W], :T₂ => [:W₂₁, :W₂ Instead of constructing the `SCM` incrementally, one can provide all the specified equations at once: ```@example scm -scm = SCM( +scm = SCM([ :Y => [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :W, :C], :T₁ => [:W₁₁, :W₁₂, :W], - :T₂ => [:W₂₁, :W₂₂, :W], -) + :T₂ => [:W₂₁, :W₂₂, :W] +]) ``` ## Classic Structural Causal Models @@ -45,10 +45,9 @@ There are many cases where we are interested in estimating the causal effect of ```@example scm scm = StaticSCM( - [:Y₁, :Y₂], - [:T₁, :T₂], - [:W₁, :W₂]; - outcome_extra_covariates=[:C], + outcomes=[:Y₁, :Y₂], + treatments=[:T₁, :T₂], + confounders=[:W₁, :W₂]; ) ``` diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index a478f031..eac189ab 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -65,10 +65,11 @@ Even though the role of a variable (treatment, outcome, confounder, ...) is rela The modeling stage starts from the definition of a Structural Causal Model (`SCM`). This is simply a list of relationships between the random variables in our dataset. See [Structural Causal Models](@ref) for an in-depth explanation. For our purposes, because we know the data generating process, we can define it as follows: ```@example walk-through -scm = SCM( +scm = SCM([ :Y => [:T₁, :T₂, :W₁₁, :W₁₂, :W₂₁, :W₂₂, :C], :T₁ => [:W₁₁, :W₁₂], :T₂ => [:W₂₁, :W₂₂] +] ) ``` From 25acbfd311e16e073789cc71edcc2d47ba483dba Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 7 Nov 2023 16:36:27 +0000 Subject: [PATCH 108/151] add brute force ordering of parameters --- Project.toml | 2 + src/TMLE.jl | 3 ++ src/estimand_ordering.jl | 75 ++++++++++++++++++++++++++++++++++++ test/estimand_ordering.jl | 81 +++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 5 files changed, 162 insertions(+) create mode 100644 src/estimand_ordering.jl create mode 100644 test/estimand_ordering.jl diff --git a/Project.toml b/Project.toml index 603007bb..cd5e4779 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "0.12.0" [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" +Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" @@ -41,5 +42,6 @@ PrecompileTools = "1.1.1" PrettyTables = "2.2" TableOperations = "1.2" Tables = "1.6" +Combinatorics = "1.0.2" Zygote = "0.6" julia = "1.6, 1.7, 1" diff --git a/src/TMLE.jl b/src/TMLE.jl index bc98fd28..dcd3e5d8 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -19,6 +19,7 @@ using Random import AbstractDifferentiation as AD using Graphs using MetaGraphsNext +using Combinatorics # ############################################################################# # EXPORTS @@ -38,6 +39,7 @@ export BackdoorAdjustment, identify export last_fluctuation_epsilon export estimands_from_yaml, estimands_to_yaml export to_dict, from_dict!, Configuration +export brute_force_ordering # ############################################################################# # INCLUDES @@ -49,6 +51,7 @@ include("estimands.jl") include("estimators.jl") include("estimates.jl") include("treatment_transformer.jl") +include("estimand_ordering.jl") include("counterfactual_mean_based/estimands.jl") include("counterfactual_mean_based/estimates.jl") diff --git a/src/estimand_ordering.jl b/src/estimand_ordering.jl new file mode 100644 index 00000000..62676b2d --- /dev/null +++ b/src/estimand_ordering.jl @@ -0,0 +1,75 @@ + +function maybe_update_counts!(dict, key) + if haskey(dict, key) + dict[key] += 1 + else + dict[key] = 1 + end +end + +function nuisance_counts(estimands) + η_counts = Dict() + for Ψ in estimands + η = TMLE.get_relevant_factors(Ψ) + for ps in η.propensity_score + maybe_update_counts!(η_counts, ps) + end + maybe_update_counts!(η_counts, η.outcome_mean) + end + return η_counts +end + +""" + evaluate_proxy_costs(estimands, η_counts; verbosity=0) + +This function evaluates proxy measures for both: + - The computational cost (1 per model fit) + - The maximum memory used (1 per model) + +that would be used while performing estimation for an ordered list of estimands. +This is assuming that after each estimation, we purge the cache from models that will +not be subsequently used. That way, we are guaranteed that we minimize computational cost. +""" +function evaluate_proxy_costs(estimands, η_counts; verbosity=0) + η_counts = deepcopy(η_counts) + cache = Set() + maxmem = 0 + compcost = 0 + for (estimand_index, Ψ) ∈ enumerate(estimands) + η = get_relevant_factors(Ψ) + models = (η.propensity_score..., η.outcome_mean) + # Append cache + for model in models + if model ∉ cache + compcost += 1 + push!(cache, model) + end + end + # Update maxmem + maxmem = max(maxmem, length(cache)) + verbosity > 0 && @info string("Cache size after estimand $estimand_index: ", length(cache)) + # Free cache from models that are not useful anymore + for model in models + η_counts[model] -= 1 + if η_counts[model] <= 0 + pop!(cache, model) + end + end + end + + return maxmem, compcost +end + + +function brute_force_ordering(estimands; η_counts = nuisance_counts(estimands)) + optimal_ordering = estimands + optimal_maxmem, optimal_compcost = evaluate_proxy_costs(estimands, η_counts) + for perm ∈ Combinatorics.permutations(estimands) + perm_maxmem, _ = evaluate_proxy_costs(perm, η_counts) + if perm_maxmem < optimal_maxmem + optimal_ordering = perm + optimal_maxmem = perm_maxmem + end + end + return optimal_ordering, optimal_maxmem, optimal_compcost +end \ No newline at end of file diff --git a/test/estimand_ordering.jl b/test/estimand_ordering.jl new file mode 100644 index 00000000..380116ce --- /dev/null +++ b/test/estimand_ordering.jl @@ -0,0 +1,81 @@ +module TestEstimandOrdering + +using TMLE +using Test + +scm = SCM([ + :Y₁ => [:T₁, :T₂, :W₁, :W₂, :C], + :Y₂ => [:T₁, :T₃, :W₃, :W₂, :C], + :T₁ => [:W₁, :W₂], + :T₂ => [:W₂], + :T₃ => [:W₃, :W₂,], +]) +causal_estimands = [ + ATE( + outcome=:Y₁, + treatment_values=(T₁=(case=1, control=0),) + ), + ATE( + outcome=:Y₁, + treatment_values=(T₁=(case=2, control=0),) + ), + CM( + outcome=:Y₁, + treatment_values=(T₁=(case=2, control=0), T₂=(case=1, control=0)) + ), + CM( + outcome=:Y₁, + treatment_values=(T₂=(case=1, control=0),) + ), + IATE( + outcome=:Y₁, + treatment_values=(T₁=(case=2, control=0), T₂=(case=1, control=0)) + ), + ATE( + outcome=:Y₂, + treatment_values=(T₁=(case=2, control=0),) + ), + ATE( + outcome=:Y₂, + treatment_values=(T₃=(case=2, control=0),) + ), + ATE( + outcome=:Y₂, + treatment_values=(T₁=(case=1, control=0), T₃=(case=2, control=0),) + ), +] +statistical_estimands = [identify(x, scm) for x in causal_estimands] + +@testset "Test evaluate_proxy_costs" begin + # Estimand ID || Required models + # 1 || (T₁, Y₁|T₁) + # 2 || (T₁, Y₁|T₁) + # 3 || (T₁, T₂, Y₁|T₁,T₂) + # 4 || (T₂, Y₁|T₂) + # 5 || (T₁, T₂, Y₁|T₁,T₂) + # 6 || (T₁, Y₂|T₁) + # 7 || (T₃, Y₂|T₃) + # 8 || (T₁, T₃, Y₂|T₁,T₃) + + η_counts = TMLE.nuisance_counts(statistical_estimands) + @test η_counts == Dict( + TMLE.ConditionalDistribution(:Y₂, (:T₁, :W₁, :W₂)) => 1, + TMLE.ConditionalDistribution(:Y₂, (:T₁, :T₃, :W₁, :W₂, :W₃)) => 1, + TMLE.ConditionalDistribution(:Y₁, (:T₁, :T₂, :W₁, :W₂)) => 2, + TMLE.ConditionalDistribution(:T₁, (:W₁, :W₂)) => 6, + TMLE.ConditionalDistribution(:T₃, (:W₂, :W₃)) => 2, + TMLE.ConditionalDistribution(:T₂, (:W₂,)) => 3, + TMLE.ConditionalDistribution(:Y₁, (:T₁, :W₁, :W₂)) => 2, + TMLE.ConditionalDistribution(:Y₁, (:T₂, :W₂)) => 1, + TMLE.ConditionalDistribution(:Y₂, (:T₃, :W₂, :W₃)) => 1 + ) + @test TMLE.evaluate_proxy_costs(statistical_estimands, η_counts) == (4, 9) + optimal_ordering, optimal_maxmem, optimal_compcost = brute_force_ordering(statistical_estimands) + @test optimal_maxmem == 3 + @test optimal_compcost == 9 +end + + +end + +true \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 3d11f7ef..e1994cbe 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,6 +9,7 @@ using Test @test include("treatment_transformer.jl") @test include("scm.jl") @test include("configuration.jl") + @test include("estimand_ordering.jl") @test include("counterfactual_mean_based/estimands.jl") @test include("counterfactual_mean_based/clever_covariate.jl") From 5d0272291fef18bdfaa3cd0df0b6306386dd7c6d Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 7 Nov 2023 17:08:55 +0000 Subject: [PATCH 109/151] add early stopping to brute force --- src/estimand_ordering.jl | 36 ++++++++++++++++++++++++++++++++++-- test/estimand_ordering.jl | 4 +++- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/estimand_ordering.jl b/src/estimand_ordering.jl index 62676b2d..5f57cdd2 100644 --- a/src/estimand_ordering.jl +++ b/src/estimand_ordering.jl @@ -1,4 +1,3 @@ - function maybe_update_counts!(dict, key) if haskey(dict, key) dict[key] += 1 @@ -60,9 +59,37 @@ function evaluate_proxy_costs(estimands, η_counts; verbosity=0) return maxmem, compcost end +""" + get_min_maxmem_lowerbound(estimands) -function brute_force_ordering(estimands; η_counts = nuisance_counts(estimands)) +The maximum number of models for a single estimand is a lower bound +on the cache size. It can be computed in a single pass, i.e. in O(N). +""" +function get_min_maxmem_lowerbound(estimands) + min_maxmem_lowerbound = 0 + for Ψ in estimands + η = get_relevant_factors(Ψ) + candidate_min = length((η.propensity_score..., η.outcome_mean)) + if candidate_min > min_maxmem_lowerbound + min_maxmem_lowerbound = candidate_min + end + end + return min_maxmem_lowerbound +end + +""" + brute_force_ordering(estimands; η_counts = nuisance_counts(estimands)) + +Finds an optimal ordering of the estimands to minimize maximum cache size. +The approach is a brute force one, all permutations are generated and evaluated, +if a minimum is found fast it is immediatly returned. +The theoretical complexity is in O(N!). However due to the stop fast approach and +the shuffling, this is actually expected to be much smaller than that. +""" +function brute_force_ordering(estimands; η_counts=nuisance_counts(estimands), do_shuffle=true, rng=Random.default_rng(), verbosity=0) optimal_ordering = estimands + estimands = do_shuffle ? shuffle(rng, estimands) : estimands + min_maxmem_lowerbound = get_min_maxmem_lowerbound(estimands) optimal_maxmem, optimal_compcost = evaluate_proxy_costs(estimands, η_counts) for perm ∈ Combinatorics.permutations(estimands) perm_maxmem, _ = evaluate_proxy_costs(perm, η_counts) @@ -70,6 +97,11 @@ function brute_force_ordering(estimands; η_counts = nuisance_counts(estimands)) optimal_ordering = perm optimal_maxmem = perm_maxmem end + # Stop fast if the lower bound is reached + if optimal_maxmem == min_maxmem_lowerbound + verbosity > 0 && @info(string("Lower bound reached, stopping.")) + return optimal_ordering, optimal_maxmem, optimal_compcost + end end return optimal_ordering, optimal_maxmem, optimal_compcost end \ No newline at end of file diff --git a/test/estimand_ordering.jl b/test/estimand_ordering.jl index 380116ce..b422cfe2 100644 --- a/test/estimand_ordering.jl +++ b/test/estimand_ordering.jl @@ -2,6 +2,7 @@ module TestEstimandOrdering using TMLE using Test +using StableRNGs scm = SCM([ :Y₁ => [:T₁, :T₂, :W₁, :W₂, :C], @@ -70,7 +71,8 @@ statistical_estimands = [identify(x, scm) for x in causal_estimands] TMLE.ConditionalDistribution(:Y₂, (:T₃, :W₂, :W₃)) => 1 ) @test TMLE.evaluate_proxy_costs(statistical_estimands, η_counts) == (4, 9) - optimal_ordering, optimal_maxmem, optimal_compcost = brute_force_ordering(statistical_estimands) + @test TMLE.get_min_maxmem_lowerbound(statistical_estimands) == 3 + optimal_ordering, optimal_maxmem, optimal_compcost = @test_logs (:info, "Lower bound reached, stopping.") brute_force_ordering(statistical_estimands, verbosity=1, rng=StableRNG(123)) @test optimal_maxmem == 3 @test optimal_compcost == 9 end From df25bf38b8ed6fd96f375cfc77dcf5bc18e5b6c3 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 9 Nov 2023 16:40:46 +0000 Subject: [PATCH 110/151] add brute_force and groups ordering of parameters --- Project.toml | 2 +- src/TMLE.jl | 2 +- src/estimand_ordering.jl | 67 ++++++++++++++++++++++++++++++++++++--- test/estimand_ordering.jl | 20 +++++++++--- 4 files changed, 79 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index cd5e4779..6161d0f4 100644 --- a/Project.toml +++ b/Project.toml @@ -28,6 +28,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractDifferentiation = "0.6.0" CategoricalArrays = "0.10" +Combinatorics = "1.0.2" Distributions = "0.25" GLM = "1.8.2" Graphs = "1.8" @@ -42,6 +43,5 @@ PrecompileTools = "1.1.1" PrettyTables = "2.2" TableOperations = "1.2" Tables = "1.6" -Combinatorics = "1.0.2" Zygote = "0.6" julia = "1.6, 1.7, 1" diff --git a/src/TMLE.jl b/src/TMLE.jl index dcd3e5d8..c61eaf64 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -39,7 +39,7 @@ export BackdoorAdjustment, identify export last_fluctuation_epsilon export estimands_from_yaml, estimands_to_yaml export to_dict, from_dict!, Configuration -export brute_force_ordering +export brute_force_ordering, groups_ordering # ############################################################################# # INCLUDES diff --git a/src/estimand_ordering.jl b/src/estimand_ordering.jl index 5f57cdd2..e33bf756 100644 --- a/src/estimand_ordering.jl +++ b/src/estimand_ordering.jl @@ -77,6 +77,29 @@ function get_min_maxmem_lowerbound(estimands) return min_maxmem_lowerbound end +estimands_permutation_generator(estimands) = Combinatorics.permutations(estimands) + +function get_propensity_score_groups(estimands_and_nuisances) + ps_groups = [[]] + current_ps = estimands_and_nuisances[1][2] + for (index, Ψ_and_ηs) ∈ enumerate(estimands_and_nuisances) + new_ps = Ψ_and_ηs[2] + if new_ps == current_ps + push!(ps_groups[end], index) + else + current_ps = new_ps + push!(ps_groups, [index]) + end + end + return ps_groups +end + +function propensity_score_group_based_permutation_generator(estimands, estimands_and_nuisances) + ps_groups = get_propensity_score_groups(estimands_and_nuisances) + group_permutations = Combinatorics.permutations(collect(1:length(ps_groups))) + return (vcat((estimands[ps_groups[index]] for index in group_perm)...) for group_perm in group_permutations) +end + """ brute_force_ordering(estimands; η_counts = nuisance_counts(estimands)) @@ -86,12 +109,12 @@ if a minimum is found fast it is immediatly returned. The theoretical complexity is in O(N!). However due to the stop fast approach and the shuffling, this is actually expected to be much smaller than that. """ -function brute_force_ordering(estimands; η_counts=nuisance_counts(estimands), do_shuffle=true, rng=Random.default_rng(), verbosity=0) +function brute_force_ordering(estimands; permutation_generator = estimands_permutation_generator(estimands), η_counts=nuisance_counts(estimands), do_shuffle=true, rng=Random.default_rng(), verbosity=0) optimal_ordering = estimands estimands = do_shuffle ? shuffle(rng, estimands) : estimands min_maxmem_lowerbound = get_min_maxmem_lowerbound(estimands) - optimal_maxmem, optimal_compcost = evaluate_proxy_costs(estimands, η_counts) - for perm ∈ Combinatorics.permutations(estimands) + optimal_maxmem, _ = evaluate_proxy_costs(estimands, η_counts) + for perm ∈ permutation_generator perm_maxmem, _ = evaluate_proxy_costs(perm, η_counts) if perm_maxmem < optimal_maxmem optimal_ordering = perm @@ -100,8 +123,42 @@ function brute_force_ordering(estimands; η_counts=nuisance_counts(estimands), d # Stop fast if the lower bound is reached if optimal_maxmem == min_maxmem_lowerbound verbosity > 0 && @info(string("Lower bound reached, stopping.")) - return optimal_ordering, optimal_maxmem, optimal_compcost + return optimal_ordering end end - return optimal_ordering, optimal_maxmem, optimal_compcost + return optimal_ordering +end + +""" + groups_ordering(estimands) + +This will order estimands based on: propensity score first, outcome mean second. This heuristic should +work reasonably well in practice. It could be optimized further by: +- Organising the propensity score groups that share similar components to be close together. +- Brute forcing the ordering of these groups to find an optimal one. +""" +function groups_ordering(estimands; brute_force=false, do_shuffle=true, rng=Random.default_rng(), verbosity=0) + # Sort estimands based on propensity_score first and outcome_mean second + estimands_and_nuisances = [] + for Ψ in estimands + η = TMLE.get_relevant_factors(Ψ) + push!(estimands_and_nuisances, (Ψ, η.propensity_score, η.outcome_mean)) + end + sort!(estimands_and_nuisances, by = x -> (Tuple(TMLE.variables(ps) for ps in x[2]), TMLE.variables(x[3]))) + + # Sorted estimands only + estimands = [x[1] for x in estimands_and_nuisances] + + # Brute force on the propensity score groups + if brute_force + + return brute_force_ordering(estimands; + permutation_generator = propensity_score_group_based_permutation_generator(estimands, estimands_and_nuisances), + do_shuffle=do_shuffle, + rng=rng, + verbosity=verbosity + ) + else + return estimands + end end \ No newline at end of file diff --git a/test/estimand_ordering.jl b/test/estimand_ordering.jl index b422cfe2..d1a924c4 100644 --- a/test/estimand_ordering.jl +++ b/test/estimand_ordering.jl @@ -47,7 +47,7 @@ causal_estimands = [ ] statistical_estimands = [identify(x, scm) for x in causal_estimands] -@testset "Test evaluate_proxy_costs" begin +@testset "Test ordering strategies" begin # Estimand ID || Required models # 1 || (T₁, Y₁|T₁) # 2 || (T₁, Y₁|T₁) @@ -57,7 +57,7 @@ statistical_estimands = [identify(x, scm) for x in causal_estimands] # 6 || (T₁, Y₂|T₁) # 7 || (T₃, Y₂|T₃) # 8 || (T₁, T₃, Y₂|T₁,T₃) - + # ---------------------------------- η_counts = TMLE.nuisance_counts(statistical_estimands) @test η_counts == Dict( TMLE.ConditionalDistribution(:Y₂, (:T₁, :W₁, :W₂)) => 1, @@ -72,9 +72,19 @@ statistical_estimands = [identify(x, scm) for x in causal_estimands] ) @test TMLE.evaluate_proxy_costs(statistical_estimands, η_counts) == (4, 9) @test TMLE.get_min_maxmem_lowerbound(statistical_estimands) == 3 - optimal_ordering, optimal_maxmem, optimal_compcost = @test_logs (:info, "Lower bound reached, stopping.") brute_force_ordering(statistical_estimands, verbosity=1, rng=StableRNG(123)) - @test optimal_maxmem == 3 - @test optimal_compcost == 9 + # The brute force solution returns the optimal solution + optimal_ordering = @test_logs (:info, "Lower bound reached, stopping.") brute_force_ordering(statistical_estimands, verbosity=1, rng=StableRNG(123)) + @test TMLE.evaluate_proxy_costs(optimal_ordering, η_counts) == (3, 9) + # Creating a bad ordering + bad_ordering = statistical_estimands[[1, 7, 3, 6, 2, 5, 8, 4]] + @test TMLE.evaluate_proxy_costs(bad_ordering, η_counts) == (6, 9) + # Without the brute force on groups, the solution is not necessarily optimal + # but still widely improved + ordering_from_groups = groups_ordering(bad_ordering) + @test TMLE.evaluate_proxy_costs(ordering_from_groups, η_counts) == (4, 9) + # Adding a layer of brute forcing results in an optimal ordering + ordering_from_groups_with_brute_force = groups_ordering(bad_ordering, brute_force=true) + @test TMLE.evaluate_proxy_costs(ordering_from_groups_with_brute_force, η_counts) == (3, 9) end From dd6cc3a9c891a1cef8671d4797eee98e02dde5bd Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 9 Nov 2023 16:45:54 +0000 Subject: [PATCH 111/151] change subtitle --- docs/src/user_guide/estimation.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index 5c5e2a92..cb19633b 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -131,7 +131,7 @@ Similarly, one could build CV-OSE: cvose = OSE(models, resampling=CV(nfolds=3)) ``` -## Reusing the SCM +## Caching model fits Let's now see how the `cache` can be reused with a new estimand, say the Total Average Treatment Effect of both `T₁` and `T₂`. From 5d089994a60526910825580e50b7ed8ae6c63c26 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 9 Nov 2023 17:58:54 +0000 Subject: [PATCH 112/151] add default models --- docs/src/index.md | 2 +- docs/src/user_guide/estimation.md | 8 ++--- docs/src/walk_through.md | 4 +-- examples/double_robustness.jl | 2 +- src/counterfactual_mean_based/estimators.jl | 29 +++++++++++++------ src/estimand_ordering.jl | 1 - src/utils.jl | 11 +++++-- test/composition.jl | 6 ++-- .../3points_interactions.jl | 2 +- .../non_regression_test.jl | 20 ++++++------- test/helper_fns.jl | 8 ++--- test/missing_management.jl | 2 +- 12 files changed, 55 insertions(+), 40 deletions(-) diff --git a/docs/src/index.md b/docs/src/index.md index fc8c4024..7f9fcbbf 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -68,7 +68,7 @@ The Average Treatment Effect of ``T`` on ``Y`` confounded by ``W`` is defined as ```@example quick-start models = (Y=with_encoder(LinearRegressor()), T = LogisticClassifier()) -tmle = TMLEE(models) +tmle = TMLEE(models=models) result, _ = tmle(Ψ, dataset, verbosity=0); result ``` diff --git a/docs/src/user_guide/estimation.md b/docs/src/user_guide/estimation.md index cb19633b..4a1187d3 100644 --- a/docs/src/user_guide/estimation.md +++ b/docs/src/user_guide/estimation.md @@ -71,7 +71,7 @@ models = ( T₁=LogisticClassifier(), T₂=LogisticClassifier(), ) -tmle = TMLEE(models) +tmle = TMLEE(models=models) result₁, cache = tmle(Ψ₁, dataset); result₁ nothing # hide @@ -106,7 +106,7 @@ We could now get an interest in the Average Treatment Effect of `T₂` that we w treatment_confounders=(T₂=[:W₂₁, :W₂₂],), outcome_extra_covariates=[:C] ) -ose = OSE(models) +ose = OSE(models=models) result₂, cache = ose(Ψ₂, dataset;cache=cache); result₂ nothing # hide @@ -121,14 +121,14 @@ Both TMLE and OSE can be used with sample-splitting, which, for an additional co To leverage sample-splitting, simply specify a `resampling` strategy when building an estimator: ```@example estimation -cvtmle = TMLEE(models, resampling=CV()) +cvtmle = TMLEE(models=models, resampling=CV()) cvresult₁, _ = cvtmle(Ψ₁, dataset); ``` Similarly, one could build CV-OSE: ```julia -cvose = OSE(models, resampling=CV(nfolds=3)) +cvose = OSE(models=models, resampling=CV(nfolds=3)) ``` ## Caching model fits diff --git a/docs/src/walk_through.md b/docs/src/walk_through.md index eac189ab..923c257a 100644 --- a/docs/src/walk_through.md +++ b/docs/src/walk_through.md @@ -140,7 +140,7 @@ models = ( T₁ = LogisticClassifier(), T₂ = LogisticClassifier() ) -tmle = TMLEE(models) +tmle = TMLEE(models=models) ``` Because we haven't identified the `cm` causal estimand yet, we need to provide the `scm` as well to the estimator: @@ -153,7 +153,7 @@ result Statistical Estimands can be estimated without a ``SCM``, let's use the One-Step estimator: ```@example walk-through -ose = OSE(models) +ose = OSE(models=models) result, cache = ose(statistical_iate, dataset) result ``` diff --git a/examples/double_robustness.jl b/examples/double_robustness.jl index 3a1f5952..9f7682fe 100644 --- a/examples/double_robustness.jl +++ b/examples/double_robustness.jl @@ -158,7 +158,7 @@ function tmle_inference(data) treatment_confounders=(Tcat=[:W],) ) models = (Y=with_encoder(LinearRegressor()), Tcat=LinearBinaryClassifier()) - tmle = TMLEE(models) + tmle = TMLEE(models=models) result, _ = tmle(Ψ, data; verbosity=0) lb, ub = confint(OneSampleTTest(result)) return (TMLE.estimate(result), lb, ub) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 636e978c..c46a436f 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -27,6 +27,18 @@ function get_train_validation_indices(resampling::ResamplingStrategy, factors, d )) end +function acquire_model(models, key, dataset, is_propensity_score) + # If the model is in models return it + haskey(models, key) && return models[key] + # Otherwise, if the required model is for a propensity_score, return the default + model_default = :G_default + if !is_propensity_score + # Finally, if the required model is an outcome_mean, find the type from the data + model_default = is_binary(dataset, key) ? :Q_binary_default : :Q_continuous_default + end + return models[model_default] +end + function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1) if haskey(cache, estimand) old_estimator, estimate = cache[estimand] @@ -41,7 +53,7 @@ function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict() train_validation_indices = TMLE.get_train_validation_indices(estimator.resampling, estimand, dataset) # Fit propensity score propensity_score_estimate = Tuple( - ConditionalDistributionEstimator(train_validation_indices, models[factor.outcome])( + ConditionalDistributionEstimator(train_validation_indices, acquire_model(models, factor.outcome, dataset, true))( factor, dataset; cache=cache, @@ -51,7 +63,7 @@ function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict() ) # Fit outcome mean outcome_mean = estimand.outcome_mean - model = models[outcome_mean.outcome] + model = acquire_model(models, outcome_mean.outcome, dataset, false) outcome_mean_estimate = TMLE.ConditionalDistributionEstimator( train_validation_indices, model @@ -110,7 +122,7 @@ mutable struct TMLEE <: Estimator end """ - TMLEE(models; resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) + TMLEE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) Defines a TMLE estimator using the specified models for estimation of the nuisance parameters. The estimator is a function that can be applied to estimate estimands for a dataset. @@ -130,12 +142,11 @@ been show to be more robust to positivity violation in practice. ```julia using MLJLinearModels -models = (Y = LinearRegressor(), T = LogisticClassifier()) -tmle = TMLEE(models) +tmle = TMLEE() Ψ̂ₙ, cache = tmle(Ψ, dataset) ``` """ -TMLEE(models; resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) = +TMLEE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) = TMLEE(models, resampling, ps_lowerbound, weighted, tol) function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) @@ -178,7 +189,7 @@ mutable struct OSE <: Estimator end """ - OSE(models; resampling=nothing, ps_lowerbound=1e-8) + OSE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8) Defines a One Step Estimator using the specified models for estimation of the nuisance parameters. The estimator is a function that can be applied to estimate estimands for a dataset. @@ -196,11 +207,11 @@ result in a data adaptive definition as described in [here](https://pubmed.ncbi. ```julia using MLJLinearModels models = (Y = LinearRegressor(), T = LogisticClassifier()) -ose = OSE(models) +ose = OSE() Ψ̂ₙ, cache = ose(Ψ, dataset) ``` """ -OSE(models; resampling=nothing, ps_lowerbound=1e-8) = +OSE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8) = OSE(models, resampling, ps_lowerbound) function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) diff --git a/src/estimand_ordering.jl b/src/estimand_ordering.jl index e33bf756..928feac2 100644 --- a/src/estimand_ordering.jl +++ b/src/estimand_ordering.jl @@ -151,7 +151,6 @@ function groups_ordering(estimands; brute_force=false, do_shuffle=true, rng=Rand # Brute force on the propensity score groups if brute_force - return brute_force_ordering(estimands; permutation_generator = propensity_score_group_based_permutation_generator(estimands, estimands_and_nuisances), do_shuffle=do_shuffle, diff --git a/src/utils.jl b/src/utils.jl index 9f9807bd..9afb9dca 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -60,11 +60,18 @@ function counterfactualTreatment(vals, T) for (i, name) in enumerate(Tnames)]) end - last_fluctuation(cache) = cache[:last_fluctuation] function last_fluctuation_epsilon(cache) mach = TMLE.last_fluctuation(cache).outcome_mean.machine fp = fitted_params(fitted_params(mach).fitresult.one_dimensional_path) return fp.coef -end \ No newline at end of file +end + +default_models(;Q_binary=LinearBinaryClassifier(), Q_continuous=LinearRegressor(), G=LinearBinaryClassifier(), encoder=encoder()) = ( + Q_binary_default = with_encoder(Q_binary, encoder=encoder), + Q_continuous_default = with_encoder(Q_continuous, encoder=encoder), + G_default = G +) + +is_binary(dataset, columnname) = Set(skipmissing(Tables.getcolumn(dataset, columnname))) == Set([0, 1]) \ No newline at end of file diff --git a/test/composition.jl b/test/composition.jl index 25acb08f..cd5551a7 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -48,8 +48,8 @@ end Y = with_encoder(LinearRegressor()), T = LogisticClassifier(lambda=0) ) - tmle = TMLEE(models) - ose = OSE(models) + tmle = TMLEE(models=models) + ose = OSE(models=models) cache = Dict() CM_tmle_result₁, cache = tmle(CM₁, dataset; cache=cache, verbosity=0) @@ -108,7 +108,7 @@ end Y = with_encoder(LinearRegressor()), T = LogisticClassifier(lambda=0) ) - tmle = TMLEE(models) + tmle = TMLEE(models=models) cache = Dict() CM₁ = CM( diff --git a/test/counterfactual_mean_based/3points_interactions.jl b/test/counterfactual_mean_based/3points_interactions.jl index 38b5ddee..eda40663 100644 --- a/test/counterfactual_mean_based/3points_interactions.jl +++ b/test/counterfactual_mean_based/3points_interactions.jl @@ -43,7 +43,7 @@ end T₃ = LogisticClassifier(lambda=0) ) - tmle = TMLEE(models) + tmle = TMLEE(models=models) result, cache = tmle(Ψ, dataset, verbosity=0); test_coverage(result, Ψ₀) test_fluct_decreases_risk(cache) diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 11f07964..d49e031c 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -10,7 +10,11 @@ using MLJBase @testset "Test ATE on perinatal dataset." begin # This is a non-regression test which was checked against the R tmle3 package - dataset = CSV.read(joinpath("data", "perinatal.csv"), DataFrame, missingstring=["", "NA"]) + dataset = CSV.read( + joinpath(dirname(dirname(pathof(TMLE))), "test", "data", "perinatal.csv"), + DataFrame, + missingstring=["", "NA"] + ) confounders = [:apgar1, :apgar5, :gagebrth, :mage, :meducyrs, :sexn] dataset.haz01 = categorical(dataset.haz01) dataset.parity01 = categorical(dataset.parity01, ordered=true) @@ -23,25 +27,19 @@ using MLJBase treatment_values=(parity01=(case=1, control=0),), treatment_confounders=(parity01=confounders,) ) - models = ( - haz01 = with_encoder(LinearBinaryClassifier()), - parity01 = LinearBinaryClassifier() - ) + resampling=nothing # No CV ps_lowerbound = 0.025 # Cutoff hardcoded in tmle3 weighted = false # Unweighted fluctuation verbosity = 0 # No logs - tmle = TMLEE(models; + tmle = TMLEE(; resampling=resampling, ps_lowerbound=ps_lowerbound, weighted=weighted ) tmle_result, cache = tmle(Ψ, dataset; verbosity=verbosity); - tmle.models = ( - haz01 = with_encoder(LinearBinaryClassifier()), - parity01 = LinearBinaryClassifier(fit_intercept=false) - ) + @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 l, u = confint(OneSampleTTest(tmle_result)) @test l ≈ -0.279246 atol = 1e-6 @@ -49,7 +47,7 @@ using MLJBase @test OneSampleZTest(tmle_result) isa OneSampleZTest # Naive - naive = NAIVE(models.haz01) + naive = NAIVE(LinearBinaryClassifier()) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=verbosity) @test naive_result ≈ -0.150078 atol = 1e-6 diff --git a/test/helper_fns.jl b/test/helper_fns.jl index 2ab0ae0a..7600f829 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -62,10 +62,10 @@ likelihood estimator also solves the score equation. test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.TMLEstimate, ose_result::TMLE.OSEstimate) = @test abs(mean(tmle_result.IC)) < abs(mean(ose_result.IC)) double_robust_estimators(models; resampling=CV(nfolds=3)) = ( - tmle = TMLEE(models), - ose = OSE(models), - cv_tmle = TMLEE(models, resampling=resampling), - cv_ose = TMLEE(models, resampling=resampling), + tmle = TMLEE(models=models), + ose = OSE(models=models), + cv_tmle = TMLEE(models=models, resampling=resampling), + cv_ose = TMLEE(models=models, resampling=resampling), ) function test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) diff --git a/test/missing_management.jl b/test/missing_management.jl index b601aa8a..b90b9754 100644 --- a/test/missing_management.jl +++ b/test/missing_management.jl @@ -55,7 +55,7 @@ end treatment_values=(T=(case=1, control=0),), treatment_confounders=(T=[:W],)) models=(Y=with_encoder(LinearRegressor()), T=LogisticClassifier(lambda=0)) - tmle = TMLEE(models) + tmle = TMLEE(models=models) tmle_result, cache = tmle(Ψ, dataset; verbosity=0) test_coverage(tmle_result, 1) test_fluct_decreases_risk(cache) From 3cd90b579b04661a12f14a3cda5bc792d663e850 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 10 Nov 2023 16:28:26 +0000 Subject: [PATCH 113/151] add machine data caching option --- src/counterfactual_mean_based/estimators.jl | 50 ++++++++++++------- src/counterfactual_mean_based/fluctuation.jl | 8 +-- src/estimators.jl | 8 +-- .../3points_interactions.jl | 4 +- test/counterfactual_mean_based/fluctuation.jl | 6 ++- .../non_regression_test.jl | 2 +- test/helper_fns.jl | 8 +-- test/missing_management.jl | 2 +- 8 files changed, 54 insertions(+), 34 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index c46a436f..be128e9a 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -39,7 +39,7 @@ function acquire_model(models, key, dataset, is_propensity_score) return models[model_default] end -function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1) +function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1, machine_cache=false) if haskey(cache, estimand) old_estimator, estimate = cache[estimand] if key(old_estimator) == key(estimator) @@ -57,17 +57,18 @@ function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict() factor, dataset; cache=cache, - verbosity=verbosity + verbosity=verbosity, + machine_cache=machine_cache ) for factor in estimand.propensity_score ) # Fit outcome mean outcome_mean = estimand.outcome_mean model = acquire_model(models, outcome_mean.outcome, dataset, false) - outcome_mean_estimate = TMLE.ConditionalDistributionEstimator( + outcome_mean_estimate = ConditionalDistributionEstimator( train_validation_indices, model - )(outcome_mean, dataset; cache=cache, verbosity=verbosity) + )(outcome_mean, dataset; cache=cache, verbosity=verbosity, machine_cache=machine_cache) # Build estimate estimate = MLCMRelevantFactors(estimand, outcome_mean_estimate, propensity_score_estimate) # Update cache @@ -84,20 +85,22 @@ struct TargetedCMRelevantFactorsEstimator model::Fluctuation end -TargetedCMRelevantFactorsEstimator(Ψ, initial_factors_estimate; tol=nothing, ps_lowerbound=1e-8, weighted=false) = +TargetedCMRelevantFactorsEstimator(Ψ, initial_factors_estimate; tol=nothing, ps_lowerbound=1e-8, weighted=false, machine_cache=false) = TargetedCMRelevantFactorsEstimator(TMLE.Fluctuation(Ψ, initial_factors_estimate; tol=tol, ps_lowerbound=ps_lowerbound, - weighted=weighted + weighted=weighted, + cache=machine_cache )) -function (estimator::TargetedCMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1) +function (estimator::TargetedCMRelevantFactorsEstimator)(estimand, dataset; cache=Dict(), verbosity=1, machine_cache=false) model = estimator.model # Fluctuate outcome model fluctuated_outcome_mean = MLConditionalDistributionEstimator(model)( model.initial_factors.outcome_mean.estimand, dataset, - verbosity=verbosity + verbosity=verbosity, + machine_cache=machine_cache ) # Do not fluctuate propensity score fluctuated_propensity_score = model.initial_factors.propensity_score @@ -119,10 +122,11 @@ mutable struct TMLEE <: Estimator ps_lowerbound::Union{Float64, Nothing} weighted::Bool tol::Union{Float64, Nothing} + machine_cache::Bool end """ - TMLEE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) + TMLEE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing, machine_cache=false) Defines a TMLE estimator using the specified models for estimation of the nuisance parameters. The estimator is a function that can be applied to estimate estimands for a dataset. @@ -137,6 +141,7 @@ result in a data adaptive definition as described in [here](https://pubmed.ncbi. - weighted: Whether the fluctuation model is a classig GLM or a weighted version. The weighted fluctuation has been show to be more robust to positivity violation in practice. - tol: This is not used at the moment. +- machine_cache: Whether MLJ.machine created during estimation should cache data. # Example @@ -146,8 +151,8 @@ tmle = TMLEE() Ψ̂ₙ, cache = tmle(Ψ, dataset) ``` """ -TMLEE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing) = - TMLEE(models, resampling, ps_lowerbound, weighted, tol) +TMLEE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted=false, tol=nothing, machine_cache=false) = + TMLEE(models, resampling, ps_lowerbound, weighted, tol, machine_cache) function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset @@ -157,7 +162,11 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict() nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, tmle.resampling) initial_factors_estimator = TMLE.CMRelevantFactorsEstimator(tmle.resampling, tmle.models) - initial_factors_estimate = initial_factors_estimator(relevant_factors, initial_factors_dataset; cache=cache, verbosity=verbosity) + initial_factors_estimate = initial_factors_estimator(relevant_factors, initial_factors_dataset; + cache=cache, + verbosity=verbosity, + machine_cache=tmle.machine_cache + ) # Get propensity score truncation threshold n = nrows(nomissing_dataset) ps_lowerbound = TMLE.ps_lower_bound(n, tmle.ps_lowerbound) @@ -168,9 +177,14 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict() initial_factors_estimate; tol=tmle.tol, ps_lowerbound=tmle.ps_lowerbound, - weighted=tmle.weighted + weighted=tmle.weighted, + machine_cache=tmle.machine_cache ) - targeted_factors_estimate = targeted_factors_estimator(relevant_factors, nomissing_dataset; cache=cache, verbosity=verbosity) + targeted_factors_estimate = targeted_factors_estimator(relevant_factors, nomissing_dataset; + cache=cache, + verbosity=verbosity, + machine_cache=tmle.machine_cache + ) # Estimation results after TMLE IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) verbosity >= 1 && @info "Done." @@ -186,10 +200,11 @@ mutable struct OSE <: Estimator models::NamedTuple resampling::Union{Nothing, ResamplingStrategy} ps_lowerbound::Union{Float64, Nothing} + machine_cache::Bool end """ - OSE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8) + OSE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false) Defines a One Step Estimator using the specified models for estimation of the nuisance parameters. The estimator is a function that can be applied to estimate estimands for a dataset. @@ -201,6 +216,7 @@ function that can be applied to estimate estimands for a dataset. any valid `MLJ.ResamplingStrategy` will result in CV-OSE. - ps_lowerbound: Lowerbound for the propensity score to avoid division by 0. The special value `nothing` will result in a data adaptive definition as described in [here](https://pubmed.ncbi.nlm.nih.gov/35512316/). +- machine_cache: Whether MLJ.machine created during estimation should cache data. # Example @@ -211,8 +227,8 @@ ose = OSE() Ψ̂ₙ, cache = ose(Ψ, dataset) ``` """ -OSE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8) = - OSE(models, resampling, ps_lowerbound) +OSE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_cache=false) = + OSE(models, resampling, ps_lowerbound, machine_cache) function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index 63613064..b693d47c 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -4,10 +4,11 @@ mutable struct Fluctuation <: MLJBase.Supervised tol::Union{Nothing, Float64} ps_lowerbound::Float64 weighted::Bool + cache::Bool end -Fluctuation(Ψ, initial_factors; tol=nothing, ps_lowerbound=1e-8, weighted=false) = - Fluctuation(Ψ, initial_factors, tol, ps_lowerbound, weighted) +Fluctuation(Ψ, initial_factors; tol=nothing, ps_lowerbound=1e-8, weighted=false, cache=false) = + Fluctuation(Ψ, initial_factors, tol, ps_lowerbound, weighted, cache) one_dimensional_path(target_scitype::Type{T}) where T <: AbstractVector{<:MLJBase.Continuous} = LinearRegressor(fit_intercept=false, offsetcol = :offset) one_dimensional_path(target_scitype::Type{T}) where T <: AbstractVector{<:Finite} = LinearBinaryClassifier(fit_intercept=false, offsetcol = :offset) @@ -43,7 +44,8 @@ function MLJBase.fit(model::Fluctuation, verbosity, X, y) one_dimensional_path(scitype(y)), clever_covariate_and_offset, y, - weights, + weights, + cache=model.cache ) fit!(mach, verbosity=verbosity) diff --git a/src/estimators.jl b/src/estimators.jl index 02e0cf67..933b7858 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -8,7 +8,7 @@ struct MLConditionalDistributionEstimator <: Estimator model::MLJBase.Supervised end -function (estimator::MLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1) +function (estimator::MLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1, machine_cache=false) # Lookup in cache if haskey(cache, estimand) old_estimator, estimate = cache[estimand] @@ -23,7 +23,7 @@ function (estimator::MLConditionalDistributionEstimator)(estimand, dataset; cach # Fit Conditional DIstribution using MLJ X = selectcols(relevant_dataset, estimand.parents) y = Tables.getcolumn(relevant_dataset, estimand.outcome) - mach = machine(estimator.model, X, y) + mach = machine(estimator.model, X, y, cache=machine_cache) fit!(mach, verbosity=verbosity-1) # Build estimate estimate = MLConditionalDistribution(estimand, mach) @@ -45,7 +45,7 @@ struct SampleSplitMLConditionalDistributionEstimator <: Estimator train_validation_indices::Tuple end -function (estimator::SampleSplitMLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1) +function (estimator::SampleSplitMLConditionalDistributionEstimator)(estimand, dataset; cache=Dict(), verbosity=1, machine_cache=false) # Lookup in cache if haskey(cache, estimand) old_estimator, estimate = cache[estimand] @@ -65,7 +65,7 @@ function (estimator::SampleSplitMLConditionalDistributionEstimator)(estimand, da train_dataset = selectrows(relevant_dataset, train_indices) Xtrain = selectcols(train_dataset, estimand.parents) ytrain = Tables.getcolumn(train_dataset, estimand.outcome) - mach = machine(estimator.model, Xtrain, ytrain) + mach = machine(estimator.model, Xtrain, ytrain, cache=machine_cache) fit!(mach, verbosity=verbosity-1) machines[index] = mach end diff --git a/test/counterfactual_mean_based/3points_interactions.jl b/test/counterfactual_mean_based/3points_interactions.jl index eda40663..436aaa6d 100644 --- a/test/counterfactual_mean_based/3points_interactions.jl +++ b/test/counterfactual_mean_based/3points_interactions.jl @@ -10,7 +10,7 @@ using CategoricalArrays using Test using LogExpFunctions -include(joinpath(dirname(@__DIR__), "helper_fns.jl")) +include(joinpath(dirname(dirname(pathof(TMLE))), "test", "helper_fns.jl")) function dataset_scm_and_truth(;n=1000) rng = StableRNG(123) @@ -43,7 +43,7 @@ end T₃ = LogisticClassifier(lambda=0) ) - tmle = TMLEE(models=models) + tmle = TMLEE(models=models, machine_cache=true) result, cache = tmle(Ψ, dataset, verbosity=0); test_coverage(result, Ψ₀) test_fluct_decreases_risk(cache) diff --git a/test/counterfactual_mean_based/fluctuation.jl b/test/counterfactual_mean_based/fluctuation.jl index c20423a8..4ea5461b 100644 --- a/test/counterfactual_mean_based/fluctuation.jl +++ b/test/counterfactual_mean_based/fluctuation.jl @@ -38,7 +38,8 @@ using DataFrames fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; tol=nothing, ps_lowerbound=ps_lowerbound, - weighted=true, + weighted=true, + cache=true ) fitresult, cache, report = MLJBase.fit(fluctuation, 0, X, y) @test fitted_params(fitresult.one_dimensional_path).features == [:covariate] @@ -102,7 +103,8 @@ end fluctuation = TMLE.Fluctuation(Ψ, η̂ₙ; tol=nothing, ps_lowerbound=1e-8, - weighted=false, + weighted=false, + cache=true ) fitresult, cache, report = MLJBase.fit(fluctuation, 0, X, y) @test cache.weighted_covariate ≈ [2.45, -3.27, -3.27, 2.45, 2.45, -6.13, 8.17] atol=0.01 diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index d49e031c..749c62ec 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -47,7 +47,7 @@ using MLJBase @test OneSampleZTest(tmle_result) isa OneSampleZTest # Naive - naive = NAIVE(LinearBinaryClassifier()) + naive = NAIVE(with_encoder(LinearBinaryClassifier())) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=verbosity) @test naive_result ≈ -0.150078 atol = 1e-6 diff --git a/test/helper_fns.jl b/test/helper_fns.jl index 7600f829..f3f11fe3 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -62,10 +62,10 @@ likelihood estimator also solves the score equation. test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.TMLEstimate, ose_result::TMLE.OSEstimate) = @test abs(mean(tmle_result.IC)) < abs(mean(ose_result.IC)) double_robust_estimators(models; resampling=CV(nfolds=3)) = ( - tmle = TMLEE(models=models), - ose = OSE(models=models), - cv_tmle = TMLEE(models=models, resampling=resampling), - cv_ose = TMLEE(models=models, resampling=resampling), + tmle = TMLEE(models=models, machine_cache=true), + ose = OSE(models=models, machine_cache=true), + cv_tmle = TMLEE(models=models, resampling=resampling, machine_cache=true), + cv_ose = TMLEE(models=models, resampling=resampling, machine_cache=true), ) function test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) diff --git a/test/missing_management.jl b/test/missing_management.jl index b90b9754..b54fc629 100644 --- a/test/missing_management.jl +++ b/test/missing_management.jl @@ -55,7 +55,7 @@ end treatment_values=(T=(case=1, control=0),), treatment_confounders=(T=[:W],)) models=(Y=with_encoder(LinearRegressor()), T=LogisticClassifier(lambda=0)) - tmle = TMLEE(models=models) + tmle = TMLEE(models=models, machine_cache=true) tmle_result, cache = tmle(Ψ, dataset; verbosity=0) test_coverage(tmle_result, 1) test_fluct_decreases_risk(cache) From 2a74417ada8c669e323895468618c014d862eeb4 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 13 Nov 2023 13:38:39 +0000 Subject: [PATCH 114/151] add YAML extension --- Project.toml | 9 +++++++++ ext/TMLEYAMLExt.jl | 24 ++++++++++++++++++++++++ src/TMLE.jl | 2 +- src/configuration.jl | 5 ++++- test/configuration.jl | 29 ++++++++++++----------------- 5 files changed, 50 insertions(+), 19 deletions(-) create mode 100644 ext/TMLEYAMLExt.jl diff --git a/Project.toml b/Project.toml index 6161d0f4..1b12fe6e 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,14 @@ TableOperations = "ab02a1b2-a7df-11e8-156e-fb1833f50b87" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +[weakdeps] + +YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" + +[extensions] + +TMLEYAMLExt = "YAML" + [compat] AbstractDifferentiation = "0.6.0" CategoricalArrays = "0.10" @@ -43,5 +51,6 @@ PrecompileTools = "1.1.1" PrettyTables = "2.2" TableOperations = "1.2" Tables = "1.6" +YAML = "0.4.9" Zygote = "0.6" julia = "1.6, 1.7, 1" diff --git a/ext/TMLEYAMLExt.jl b/ext/TMLEYAMLExt.jl new file mode 100644 index 00000000..476f5992 --- /dev/null +++ b/ext/TMLEYAMLExt.jl @@ -0,0 +1,24 @@ +module TMLEYAMLExt + +using YAML +using TMLE + +""" + load_estimands_config(file) + +Loads a YAML configuration file containing: + - Estimands + - An SCM (optional) + - An Adjustment Method (optional) +""" +TMLE.configuration_from_yaml(file) = from_dict!(YAML.load_file(file, dicttype=Dict{Symbol, Any})) + +""" + configuration_to_yaml(file, config::Configuration) + +Writes a `Configuration` struct to a YAML file. The latter can be deserialized +with `configuration_from_yaml`. +""" +TMLE.configuration_to_yaml(file, config::Configuration) = YAML.write_file(file, to_dict(config)) + +end \ No newline at end of file diff --git a/src/TMLE.jl b/src/TMLE.jl index c61eaf64..7921ae95 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -37,7 +37,7 @@ export compose export TreatmentTransformer, with_encoder, encoder export BackdoorAdjustment, identify export last_fluctuation_epsilon -export estimands_from_yaml, estimands_to_yaml +export configuration_from_yaml, configuration_to_yaml export to_dict, from_dict!, Configuration export brute_force_ordering, groups_ordering diff --git a/src/configuration.jl b/src/configuration.jl index d0a67c13..acd3fe4a 100644 --- a/src/configuration.jl +++ b/src/configuration.jl @@ -33,4 +33,7 @@ function from_dict!(d::Dict{T, Any}) where T haskey(d, T(:type)) || return d constructor = eval(Meta.parse(pop!(d, :type))) return constructor(;[key => from_dict!(val) for (key, val) in d]...) -end \ No newline at end of file +end + +function configuration_from_yaml end +function configuration_to_yaml end \ No newline at end of file diff --git a/test/configuration.jl b/test/configuration.jl index 3e9b3c8b..2e24eff1 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -40,14 +40,13 @@ using YAML estimands=estimands ) config_dict = to_dict(configuration) - YAML.write_file(outfilename, config_dict) - dict_from_yaml = YAML.load_file(outfilename, dicttype=Dict{Symbol,Any}) - reconstructed_configuration = from_dict!(dict_from_yaml) - @test reconstructed_configuration.scm === nothing - @test reconstructed_configuration.adjustment === nothing + configuration_to_yaml(outfilename, configuration) + config_from_yaml = configuration_from_yaml(outfilename) + @test config_from_yaml.scm === nothing + @test config_from_yaml.adjustment === nothing for index in 1:length(estimands) estimand = configuration.estimands[index] - reconstructed_estimand = reconstructed_configuration.estimands[index] + reconstructed_estimand = config_from_yaml.estimands[index] @test estimand.outcome == reconstructed_estimand.outcome for treatment in keys(estimand.treatment_values) @test estimand.treatment_values[treatment] == reconstructed_estimand.treatment_values[treatment] @@ -81,18 +80,14 @@ using YAML configuration = Configuration( estimands=estimands, + scm = StaticSCM( + confounders = [:W], + outcomes = [:Y1], + treatments = [:T1, :T2]), adjustment=BackdoorAdjustment([:C1, :C2]) - ) - config_dict = to_dict(configuration) - config_dict[:scm] = Dict( - :type => "StaticSCM", - :confounders => [:W], - :outcomes => [:Y1], - :treatments => [:T1, :T2] - ) - YAML.write_file(outfilename, config_dict) - config_from_yaml = from_dict!(YAML.load_file(outfilename, dicttype=Dict{Symbol,Any})) - + ) + configuration_to_yaml(outfilename, configuration) + config_from_yaml = configuration_from_yaml(outfilename) scm = config_from_yaml.scm adjustment = config_from_yaml.adjustment # Estimand 1 From 082ceb8eececdc015161ac5d8be528da81fe0907 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 13 Nov 2023 15:37:44 +0000 Subject: [PATCH 115/151] add YAML to extra section for bacwards compatibility --- Project.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Project.toml b/Project.toml index 1b12fe6e..bf194d02 100644 --- a/Project.toml +++ b/Project.toml @@ -54,3 +54,7 @@ Tables = "1.6" YAML = "0.4.9" Zygote = "0.6" julia = "1.6, 1.7, 1" + +[extras] + +YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" From 587f9af2234e72de07ac3c33b9ce8edb7a4ef790 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 13 Nov 2023 16:14:15 +0000 Subject: [PATCH 116/151] limit config test to 1.9 --- test/runtests.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index e1994cbe..8ecf3516 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,10 @@ using Test @test include("composition.jl") @test include("treatment_transformer.jl") @test include("scm.jl") - @test include("configuration.jl") + # Requires Extensions + if VERSION >= v"1.9" + @test include("configuration.jl") + end @test include("estimand_ordering.jl") @test include("counterfactual_mean_based/estimands.jl") From c44a99393cb41f4b2e57a5555b0dffda4f6cd1df Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 17 Nov 2023 21:02:15 +0000 Subject: [PATCH 117/151] add json serialization and std field to estimates --- Project.toml | 16 ++- ext/JSONExt.jl | 28 ++++++ ext/{TMLEYAMLExt.jl => YAMLExt.jl} | 6 +- src/TMLE.jl | 1 + src/configuration.jl | 8 +- src/counterfactual_mean_based/estimates.jl | 4 + src/counterfactual_mean_based/estimators.jl | 6 +- test/Project.toml | 3 +- test/composition.jl | 4 +- test/configuration.jl | 103 +++++++++++++------- test/runtests.jl | 41 ++++---- 11 files changed, 152 insertions(+), 68 deletions(-) create mode 100644 ext/JSONExt.jl rename ext/{TMLEYAMLExt.jl => YAMLExt.jl} (74%) diff --git a/Project.toml b/Project.toml index bf194d02..882fb998 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" GLM = "38e38edf-8417-5370-95a0-9cbb8c7f171a" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c" @@ -26,12 +27,15 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] - +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" [extensions] - -TMLEYAMLExt = "YAML" +GraphMakieExt = ["GraphMakie", "CairoMakie"] +YAMLExt = "YAML" +JSONExt = "JSON" [compat] AbstractDifferentiation = "0.6.0" @@ -52,9 +56,13 @@ PrettyTables = "2.2" TableOperations = "1.2" Tables = "1.6" YAML = "0.4.9" +JSON = "0.21.4" Zygote = "0.6" julia = "1.6, 1.7, 1" [extras] - +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" +GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" + diff --git a/ext/JSONExt.jl b/ext/JSONExt.jl new file mode 100644 index 00000000..bdedb00b --- /dev/null +++ b/ext/JSONExt.jl @@ -0,0 +1,28 @@ +module JSONExt + +using JSON +using TMLE + +""" + configuration_from_json(file) + +Loads a YAML configuration file containing: + - Estimands + - An SCM (optional) + - An Adjustment Method (optional) +""" +TMLE.configuration_from_json(file) = from_dict!(JSON.parsefile(file, dicttype=Dict{Symbol, Any})) + +""" + configuration_to_json(file, config::Configuration) + +Writes a `Configuration` struct to a YAML file. The latter can be deserialized +with `configuration_from_yaml`. +""" +function TMLE.configuration_to_json(file, config; indent=1) + open(file, "w") do io + JSON.print(io, to_dict(config), indent) + end +end + +end \ No newline at end of file diff --git a/ext/TMLEYAMLExt.jl b/ext/YAMLExt.jl similarity index 74% rename from ext/TMLEYAMLExt.jl rename to ext/YAMLExt.jl index 476f5992..d4d97922 100644 --- a/ext/TMLEYAMLExt.jl +++ b/ext/YAMLExt.jl @@ -1,10 +1,10 @@ -module TMLEYAMLExt +module YAMLExt using YAML using TMLE """ - load_estimands_config(file) + configuration_from_yaml(file) Loads a YAML configuration file containing: - Estimands @@ -19,6 +19,6 @@ TMLE.configuration_from_yaml(file) = from_dict!(YAML.load_file(file, dicttype=Di Writes a `Configuration` struct to a YAML file. The latter can be deserialized with `configuration_from_yaml`. """ -TMLE.configuration_to_yaml(file, config::Configuration) = YAML.write_file(file, to_dict(config)) +TMLE.configuration_to_yaml(file, config) = YAML.write_file(file, to_dict(config)) end \ No newline at end of file diff --git a/src/TMLE.jl b/src/TMLE.jl index 7921ae95..317f32f3 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -38,6 +38,7 @@ export TreatmentTransformer, with_encoder, encoder export BackdoorAdjustment, identify export last_fluctuation_epsilon export configuration_from_yaml, configuration_to_yaml +export configuration_from_json, configuration_to_json export to_dict, from_dict!, Configuration export brute_force_ordering, groups_ordering diff --git a/src/configuration.jl b/src/configuration.jl index acd3fe4a..dac88542 100644 --- a/src/configuration.jl +++ b/src/configuration.jl @@ -20,6 +20,10 @@ function to_dict(configuration::Configuration) return config_dict end +to_dict(x) = x + +to_dict(v::AbstractVector) = [to_dict(x) for x in v] + from_dict!(x) = x from_dict!(v::AbstractVector) = [from_dict!(x) for x in v] @@ -36,4 +40,6 @@ function from_dict!(d::Dict{T, Any}) where T end function configuration_from_yaml end -function configuration_to_yaml end \ No newline at end of file +function configuration_to_yaml end +function configuration_from_json end +function configuration_to_json end \ No newline at end of file diff --git a/src/counterfactual_mean_based/estimates.jl b/src/counterfactual_mean_based/estimates.jl index 55002d35..53b6bfc3 100644 --- a/src/counterfactual_mean_based/estimates.jl +++ b/src/counterfactual_mean_based/estimates.jl @@ -26,17 +26,21 @@ string_repr(estimate::MLCMRelevantFactors) = string( struct TMLEstimate{T<:AbstractFloat} <: Estimate estimand::StatisticalCMCompositeEstimand Ψ̂::T + σ::T IC::Vector{T} end struct OSEstimate{T<:AbstractFloat} <: Estimate estimand::StatisticalCMCompositeEstimand Ψ̂::T + σ̂::T IC::Vector{T} end const EICEstimate = Union{TMLEstimate, OSEstimate} +without_ic(estimate::E) where E <: EICEstimate = E(estimate.estimand, estimate.Ψ̂, estimate.σ̂, []) + function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) testresult = OneSampleTTest(est) data = [estimate(est) confint(testresult) pvalue(testresult);] diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index be128e9a..5046dd88 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -187,9 +187,10 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict() ) # Estimation results after TMLE IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) + σ̂ = std(IC) / sqrt(size(IC, 1)) verbosity >= 1 && @info "Done." # update!(cache, relevant_factors, targeted_factors_estimate) - return TMLEstimate(Ψ, Ψ̂, IC), cache + return TMLEstimate(Ψ, Ψ̂, σ̂, IC), cache end ##################################################################### @@ -250,8 +251,9 @@ function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dic # Gradient and estimate IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) + σ̂ = std(IC) / sqrt(size(IC, 1)) verbosity >= 1 && @info "Done." - return OSEstimate(Ψ, Ψ̂ + mean(IC), IC), cache + return OSEstimate(Ψ, Ψ̂ + mean(IC), σ̂, IC), cache end ##################################################################### diff --git a/test/Project.toml b/test/Project.toml index f7b8c11d..e0d9f184 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -5,6 +5,7 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c" @@ -26,6 +27,6 @@ CSV = "0.10" DataFrames = "1.5" MLJLinearModels = "0.9" StableRNGs = "1.0" +StatisticalMeasures = "0.1.3" StatsBase = "0.33" YAML = "0.4.9" -StatisticalMeasures = "0.1.3" \ No newline at end of file diff --git a/test/composition.jl b/test/composition.jl index cd5551a7..36cd93be 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -29,8 +29,8 @@ end ) n = 10 X = rand(n, 2) - ER₁ = TMLE.TMLEstimate(Ψ, 1., X[:, 1]) - ER₂ = TMLE.TMLEstimate(Ψ, 0., X[:, 2]) + ER₁ = TMLE.TMLEstimate(Ψ, 1., 1., X[:, 1]) + ER₂ = TMLE.TMLEstimate(Ψ, 0., 1., X[:, 2]) Σ = cov(ER₁, ER₂) @test size(Σ) == (2, 2) @test Σ == cov(X) diff --git a/test/configuration.jl b/test/configuration.jl index 2e24eff1..4721d9f1 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -3,9 +3,25 @@ module TestConfiguration using Test using TMLE using YAML +using JSON + +function check_estimands(reconstructed_estimands, estimands) + for (index, estimand) in enumerate(estimands) + reconstructed_estimand = reconstructed_estimands[index] + @test estimand.outcome == reconstructed_estimand.outcome + for treatment in keys(estimand.treatment_values) + @test estimand.treatment_values[treatment] == reconstructed_estimand.treatment_values[treatment] + @test estimand.treatment_confounders[treatment] == reconstructed_estimand.treatment_confounders[treatment] + @test estimand.outcome_extra_covariates == reconstructed_estimand.outcome_extra_covariates + end + end +end + +outprefix = "test_serialized." +yamlfilename = string(outprefix, "yaml") +jsonfilename = string(outprefix, "json") @testset "Test Configurations" begin - outfilename = "test_configuration.yml" estimands = [ IATE( outcome=:Y1, @@ -36,26 +52,35 @@ using YAML treatment_confounders=(T3 = [:W], T1 = [:W]) ) ] + configuration_to_yaml(yamlfilename, estimands) + estimands_from_yaml = configuration_from_yaml(yamlfilename) + check_estimands(estimands_from_yaml, estimands) + rm(yamlfilename) + configuration_to_json(jsonfilename, estimands) + estimands_from_json = configuration_from_json(jsonfilename) + check_estimands(estimands_from_json, estimands) + rm(jsonfilename) + configuration = Configuration( estimands=estimands ) config_dict = to_dict(configuration) - configuration_to_yaml(outfilename, configuration) - config_from_yaml = configuration_from_yaml(outfilename) - @test config_from_yaml.scm === nothing - @test config_from_yaml.adjustment === nothing - for index in 1:length(estimands) - estimand = configuration.estimands[index] - reconstructed_estimand = config_from_yaml.estimands[index] - @test estimand.outcome == reconstructed_estimand.outcome - for treatment in keys(estimand.treatment_values) - @test estimand.treatment_values[treatment] == reconstructed_estimand.treatment_values[treatment] - @test estimand.treatment_confounders[treatment] == reconstructed_estimand.treatment_confounders[treatment] - @test estimand.outcome_extra_covariates == reconstructed_estimand.outcome_extra_covariates - end + configuration_to_yaml(yamlfilename, configuration) + configuration_to_json(jsonfilename, configuration) + config_from_yaml = configuration_from_yaml(yamlfilename) + config_from_json = configuration_from_json(jsonfilename) + for loaded_config in (config_from_yaml, config_from_json) + @test loaded_config.scm === nothing + @test loaded_config.adjustment === nothing + reconstructed_estimands = loaded_config.estimands + check_estimands(reconstructed_estimands, estimands) end - rm(outfilename) + rm(yamlfilename) + rm(jsonfilename) + +end +@testset "Test with an SCM and causal estimands" begin # With a StaticSCM, some Causal estimands and an Adjustment Method estimands = [ IATE( @@ -86,28 +111,34 @@ using YAML treatments = [:T1, :T2]), adjustment=BackdoorAdjustment([:C1, :C2]) ) - configuration_to_yaml(outfilename, configuration) - config_from_yaml = configuration_from_yaml(outfilename) - scm = config_from_yaml.scm - adjustment = config_from_yaml.adjustment - # Estimand 1 - estimand₁ = config_from_yaml.estimands[1] - statistical_estimand₁ = identify(adjustment, estimand₁, scm) - @test statistical_estimand₁.outcome == :Y1 - @test statistical_estimand₁.treatment_values.T1 == (case=1, control=0) - @test statistical_estimand₁.treatment_values.T2 == (case=1, control=0) - @test statistical_estimand₁.treatment_confounders.T2 == (:W,) - @test statistical_estimand₁.treatment_confounders.T1 == (:W,) - @test statistical_estimand₁.outcome_extra_covariates == (:C1, :C2) - # Estimand 2 - estimand₂ = config_from_yaml.estimands[2] - @test estimand₂.outcome == :Y3 - @test estimand₂.treatment_values.T1 == (case = 1, control = 0) - @test estimand₂.treatment_values.T3 == (case = "AC", control = "CC") - @test estimand₂.treatment_confounders.T1 == (:W,) - @test estimand₂.treatment_confounders.T3 == (:W,) + configuration_to_yaml(yamlfilename, configuration) + configuration_to_json(jsonfilename, configuration) - rm(outfilename) + config_from_yaml = configuration_from_yaml(yamlfilename) + config_from_json = configuration_from_json(jsonfilename) + + for loaded_config in (config_from_yaml, config_from_json) + scm = loaded_config.scm + adjustment = loaded_config.adjustment + # Estimand 1 + estimand₁ = loaded_config.estimands[1] + statistical_estimand₁ = identify(adjustment, estimand₁, scm) + @test statistical_estimand₁.outcome == :Y1 + @test statistical_estimand₁.treatment_values.T1 == (case=1, control=0) + @test statistical_estimand₁.treatment_values.T2 == (case=1, control=0) + @test statistical_estimand₁.treatment_confounders.T2 == (:W,) + @test statistical_estimand₁.treatment_confounders.T1 == (:W,) + @test statistical_estimand₁.outcome_extra_covariates == (:C1, :C2) + # Estimand 2 + estimand₂ = loaded_config.estimands[2] + @test estimand₂.outcome == :Y3 + @test estimand₂.treatment_values.T1 == (case = 1, control = 0) + @test estimand₂.treatment_values.T3 == (case = "AC", control = "CC") + @test estimand₂.treatment_confounders.T1 == (:W,) + @test estimand₂.treatment_confounders.T3 == (:W,) + end + rm(yamlfilename) + rm(jsonfilename) end end diff --git a/test/runtests.jl b/test/runtests.jl index 8ecf3516..9049d620 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,27 +1,30 @@ using Test +using TMLE + +TEST_DIR = joinpath(pkgdir(TMLE), "test") @time begin - @test include("utils.jl") - @test include("estimands.jl") - @test include("estimators_and_estimates.jl") - @test include("missing_management.jl") - @test include("composition.jl") - @test include("treatment_transformer.jl") - @test include("scm.jl") + @test include(joinpath(TEST_DIR, "utils.jl")) + @test include(joinpath(TEST_DIR, "estimands.jl")) + @test include(joinpath(TEST_DIR, "estimators_and_estimates.jl")) + @test include(joinpath(TEST_DIR, "missing_management.jl")) + @test include(joinpath(TEST_DIR, "composition.jl")) + @test include(joinpath(TEST_DIR, "treatment_transformer.jl")) + @test include(joinpath(TEST_DIR, "scm.jl")) # Requires Extensions if VERSION >= v"1.9" - @test include("configuration.jl") + @test include(joinpath(TEST_DIR, "configuration.jl")) end - @test include("estimand_ordering.jl") + @test include(joinpath(TEST_DIR, "estimand_ordering.jl")) - @test include("counterfactual_mean_based/estimands.jl") - @test include("counterfactual_mean_based/clever_covariate.jl") - @test include("counterfactual_mean_based/gradient.jl") - @test include("counterfactual_mean_based/fluctuation.jl") - @test include("counterfactual_mean_based/estimators_and_estimates.jl") - @test include("counterfactual_mean_based/non_regression_test.jl") - @test include("counterfactual_mean_based/double_robustness_ate.jl") - @test include("counterfactual_mean_based/double_robustness_iate.jl") - @test include("counterfactual_mean_based/3points_interactions.jl") - @test include("counterfactual_mean_based/adjustment.jl") + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/estimands.jl")) + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/clever_covariate.jl")) + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/gradient.jl")) + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/fluctuation.jl")) + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/estimators_and_estimates.jl")) + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/non_regression_test.jl")) + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/double_robustness_ate.jl")) + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/double_robustness_iate.jl")) + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/3points_interactions.jl")) + @test include(joinpath(TEST_DIR, "counterfactual_mean_based/adjustment.jl")) end \ No newline at end of file From 0118f931222891ec630b6ba14a5bb59d7f364dc1 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 17 Nov 2023 21:39:31 +0000 Subject: [PATCH 118/151] change some field names --- src/counterfactual_mean_based/estimates.jl | 26 +++++++++++---------- src/counterfactual_mean_based/estimators.jl | 10 ++++---- test/composition.jl | 4 ++-- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/counterfactual_mean_based/estimates.jl b/src/counterfactual_mean_based/estimates.jl index 53b6bfc3..5ff37f8b 100644 --- a/src/counterfactual_mean_based/estimates.jl +++ b/src/counterfactual_mean_based/estimates.jl @@ -25,21 +25,23 @@ string_repr(estimate::MLCMRelevantFactors) = string( struct TMLEstimate{T<:AbstractFloat} <: Estimate estimand::StatisticalCMCompositeEstimand - Ψ̂::T - σ::T + estimate::T + std::T + n::Int IC::Vector{T} end struct OSEstimate{T<:AbstractFloat} <: Estimate estimand::StatisticalCMCompositeEstimand - Ψ̂::T - σ̂::T + estimate::T + std::T + n::Int IC::Vector{T} end const EICEstimate = Union{TMLEstimate, OSEstimate} -without_ic(estimate::E) where E <: EICEstimate = E(estimate.estimand, estimate.Ψ̂, estimate.σ̂, []) +without_ic(estimate::E) where E <: EICEstimate = E(estimate.estimand, estimate.estimate, estimate.σ̂, []) function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) testresult = OneSampleTTest(est) @@ -48,7 +50,7 @@ function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) end struct ComposedEstimate{T<:AbstractFloat} <: Estimate - Ψ̂::Array{T} + estimate::Array{T} σ̂::Matrix{T} n::Int end @@ -69,7 +71,7 @@ end Retrieves the final estimate: after the TMLE step. """ -Distributions.estimate(est::EICEstimate) = est.Ψ̂ +Distributions.estimate(est::EICEstimate) = est.estimate """ Distributions.estimate(r::ComposedEstimate) @@ -77,7 +79,7 @@ Distributions.estimate(est::EICEstimate) = est.Ψ̂ Retrieves the final estimate: after the TMLE step. """ Distributions.estimate(est::ComposedEstimate) = - length(est.Ψ̂) == 1 ? est.Ψ̂[1] : est.Ψ̂ + length(est.estimate) == 1 ? est.estimate[1] : est.estimate """ var(r::EICEstimate) @@ -102,7 +104,7 @@ Statistics.var(est::ComposedEstimate) = Performs a Z test on the EICEstimate. """ HypothesisTests.OneSampleZTest(est::EICEstimate, Ψ₀=0) = - OneSampleZTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) + OneSampleZTest(est.estimate, est.std, est.n, Ψ₀) """ OneSampleTTest(r::EICEstimate, Ψ₀=0) @@ -110,7 +112,7 @@ HypothesisTests.OneSampleZTest(est::EICEstimate, Ψ₀=0) = Performs a T test on the EICEstimate. """ HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) = - OneSampleTTest(estimate(est), std(est.IC), size(est.IC, 1), Ψ₀) + OneSampleTTest(est.estimate, est.std, est.n, Ψ₀) """ OneSampleTTest(r::ComposedEstimate, Ψ₀=0) @@ -118,7 +120,7 @@ HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) = Performs a T test on the ComposedEstimate. """ function HypothesisTests.OneSampleTTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + @assert length(est.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." return OneSampleTTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) end @@ -128,7 +130,7 @@ end Performs a T test on the ComposedEstimate. """ function HypothesisTests.OneSampleZTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.Ψ̂) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + @assert length(est.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." return OneSampleZTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) end diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 5046dd88..822048be 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -187,10 +187,11 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict() ) # Estimation results after TMLE IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) - σ̂ = std(IC) / sqrt(size(IC, 1)) + σ̂ = std(IC) + n = size(IC, 1) verbosity >= 1 && @info "Done." # update!(cache, relevant_factors, targeted_factors_estimate) - return TMLEstimate(Ψ, Ψ̂, σ̂, IC), cache + return TMLEstimate(Ψ, Ψ̂, σ̂, n, IC), cache end ##################################################################### @@ -251,9 +252,10 @@ function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dic # Gradient and estimate IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) - σ̂ = std(IC) / sqrt(size(IC, 1)) + σ̂ = std(IC) + n = size(IC, 1) verbosity >= 1 && @info "Done." - return OSEstimate(Ψ, Ψ̂ + mean(IC), σ̂, IC), cache + return OSEstimate(Ψ, Ψ̂ + mean(IC), σ̂, n, IC), cache end ##################################################################### diff --git a/test/composition.jl b/test/composition.jl index 36cd93be..47ca69a7 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -29,8 +29,8 @@ end ) n = 10 X = rand(n, 2) - ER₁ = TMLE.TMLEstimate(Ψ, 1., 1., X[:, 1]) - ER₂ = TMLE.TMLEstimate(Ψ, 0., 1., X[:, 2]) + ER₁ = TMLE.TMLEstimate(Ψ, 1., 1., n, X[:, 1]) + ER₂ = TMLE.TMLEstimate(Ψ, 0., 1., n, X[:, 2]) Σ = cov(ER₁, ER₂) @test size(Σ) == (2, 2) @test Σ == cov(X) From 1236d1586cb8a16fe2d56e5d899abf3998c9b1f5 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 17 Nov 2023 22:05:42 +0000 Subject: [PATCH 119/151] add serialization of estimates --- src/TMLE.jl | 2 +- src/counterfactual_mean_based/estimates.jl | 18 +++++++++++++- .../non_regression_test.jl | 24 +++++++++++++++---- 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index 317f32f3..f52419dd 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -32,7 +32,7 @@ export InteractionAverageTreatmentEffect, IATE export AVAILABLE_ESTIMANDS export optimize_ordering, optimize_ordering! export TMLEE, OSE, NAIVE -export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint +export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint, emptyIC export compose export TreatmentTransformer, with_encoder, encoder export BackdoorAdjustment, identify diff --git a/src/counterfactual_mean_based/estimates.jl b/src/counterfactual_mean_based/estimates.jl index 5ff37f8b..1c31f4b8 100644 --- a/src/counterfactual_mean_based/estimates.jl +++ b/src/counterfactual_mean_based/estimates.jl @@ -31,6 +31,8 @@ struct TMLEstimate{T<:AbstractFloat} <: Estimate IC::Vector{T} end +TMLEstimate(;estimand, estimate::T, std::T, n, IC) where T = TMLEstimate(estimand, estimate, std, n, convert(Vector{T}, IC)) + struct OSEstimate{T<:AbstractFloat} <: Estimate estimand::StatisticalCMCompositeEstimand estimate::T @@ -39,9 +41,23 @@ struct OSEstimate{T<:AbstractFloat} <: Estimate IC::Vector{T} end +OSEstimate(;estimand, estimate::T, std::T, n, IC) where T = OSEstimate(estimand, estimate, std, n, convert(Vector{T}, IC)) + const EICEstimate = Union{TMLEstimate, OSEstimate} -without_ic(estimate::E) where E <: EICEstimate = E(estimate.estimand, estimate.estimate, estimate.σ̂, []) +function to_dict(estimate::T) where T <: EICEstimate + Dict( + :type => replace(string(Base.typename(T).wrapper), "TMLE." => ""), + :estimate => estimate.estimate, + :estimand => to_dict(estimate.estimand), + :std => estimate.std, + :n => estimate.n, + :IC => estimate.IC + ) +end + +emptyIC(estimate::T) where T <: EICEstimate = + T(estimate.estimand, estimate.estimate, estimate.std, estimate.n, []) function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) testresult = OneSampleTTest(est) diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 749c62ec..d42477ff 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -7,6 +7,16 @@ using DataFrames using CategoricalArrays using MLJGLMInterface using MLJBase +using JSON +using YAML + +function regression_tests(tmle_result) + @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 + l, u = confint(OneSampleTTest(tmle_result)) + @test l ≈ -0.279246 atol = 1e-6 + @test u ≈ -0.091821 atol = 1e-6 + @test OneSampleZTest(tmle_result) isa OneSampleZTest +end @testset "Test ATE on perinatal dataset." begin # This is a non-regression test which was checked against the R tmle3 package @@ -39,13 +49,17 @@ using MLJBase ) tmle_result, cache = tmle(Ψ, dataset; verbosity=verbosity); + regression_tests(tmle_result) - @test estimate(tmle_result) ≈ -0.185533 atol = 1e-6 - l, u = confint(OneSampleTTest(tmle_result)) - @test l ≈ -0.279246 atol = 1e-6 - @test u ≈ -0.091821 atol = 1e-6 - @test OneSampleZTest(tmle_result) isa OneSampleZTest + configuration_to_json("result.json", [tmle_result]) + results_from_json = configuration_from_json("result.json") + regression_tests(results_from_json[1]) + rm("result.json") + configuration_to_yaml("result.yaml", [emptyIC(tmle_result)]) + results_from_yaml = configuration_from_yaml("result.yaml") + regression_tests(results_from_yaml[1]) + rm("result.yaml") # Naive naive = NAIVE(with_encoder(LinearBinaryClassifier())) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=verbosity) From c1373cf1d210a2086769917191a6c02f89e55e2a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 17 Nov 2023 22:24:23 +0000 Subject: [PATCH 120/151] fix 1.0 tests --- src/TMLE.jl | 5 +---- .../non_regression_test.jl | 19 ++++++++++--------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index f52419dd..6594db41 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -26,11 +26,8 @@ using Combinatorics # ############################################################################# export SCM, StaticSCM, add_equations!, add_equation!, parents, vertices -export CounterfactualMean, CM -export AverageTreatmentEffect, ATE -export InteractionAverageTreatmentEffect, IATE +export CM, ATE, IATE export AVAILABLE_ESTIMANDS -export optimize_ordering, optimize_ordering! export TMLEE, OSE, NAIVE export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint, emptyIC export compose diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index d42477ff..84e9d7a7 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -50,16 +50,17 @@ end tmle_result, cache = tmle(Ψ, dataset; verbosity=verbosity); regression_tests(tmle_result) + if VERSION >= v"1.9" + configuration_to_json("result.json", [tmle_result]) + results_from_json = configuration_from_json("result.json") + regression_tests(results_from_json[1]) + rm("result.json") - configuration_to_json("result.json", [tmle_result]) - results_from_json = configuration_from_json("result.json") - regression_tests(results_from_json[1]) - rm("result.json") - - configuration_to_yaml("result.yaml", [emptyIC(tmle_result)]) - results_from_yaml = configuration_from_yaml("result.yaml") - regression_tests(results_from_yaml[1]) - rm("result.yaml") + configuration_to_yaml("result.yaml", [emptyIC(tmle_result)]) + results_from_yaml = configuration_from_yaml("result.yaml") + regression_tests(results_from_yaml[1]) + rm("result.yaml") + end # Naive naive = NAIVE(with_encoder(LinearBinaryClassifier())) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=verbosity) From 23d5393b9fbff24528bd2d9b8a3a0a16e6ffc590 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 17 Nov 2023 22:26:17 +0000 Subject: [PATCH 121/151] drop nightly from CI --- .github/workflows/CI.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index b087f6a0..ef0ff5e0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -12,7 +12,6 @@ jobs: version: - '1.6' - '1' - - 'nightly' os: - ubuntu-latest - macOS-latest From 5399c122b4400b218e5f1382e9b407191457202c Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 17 Nov 2023 23:34:44 +0000 Subject: [PATCH 122/151] try explicitely close the file --- ext/JSONExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/JSONExt.jl b/ext/JSONExt.jl index bdedb00b..24cd0481 100644 --- a/ext/JSONExt.jl +++ b/ext/JSONExt.jl @@ -20,9 +20,9 @@ Writes a `Configuration` struct to a YAML file. The latter can be deserialized with `configuration_from_yaml`. """ function TMLE.configuration_to_json(file, config; indent=1) - open(file, "w") do io - JSON.print(io, to_dict(config), indent) - end + io = open(file, "w") + JSON.print(io, to_dict(config), indent) + close(io) end end \ No newline at end of file From 61ba7e05040f0f4c9973758a0e330806f5644f1a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 18 Nov 2023 00:16:48 +0000 Subject: [PATCH 123/151] try something else --- ext/JSONExt.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ext/JSONExt.jl b/ext/JSONExt.jl index 24cd0481..2cd2fb33 100644 --- a/ext/JSONExt.jl +++ b/ext/JSONExt.jl @@ -20,9 +20,9 @@ Writes a `Configuration` struct to a YAML file. The latter can be deserialized with `configuration_from_yaml`. """ function TMLE.configuration_to_json(file, config; indent=1) - io = open(file, "w") - JSON.print(io, to_dict(config), indent) - close(io) + open(file, "w") do io + write(io, JSON.json(to_dict(config), indent)) + end end end \ No newline at end of file From dc99bb34049e48a349dbc501f1f91de6c7ed42fe Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sat, 18 Nov 2023 10:47:20 +0000 Subject: [PATCH 124/151] fix bug in groups ordering --- src/estimand_ordering.jl | 10 ++++++++-- test/estimand_ordering.jl | 1 + 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/estimand_ordering.jl b/src/estimand_ordering.jl index 928feac2..a243243e 100644 --- a/src/estimand_ordering.jl +++ b/src/estimand_ordering.jl @@ -109,9 +109,15 @@ if a minimum is found fast it is immediatly returned. The theoretical complexity is in O(N!). However due to the stop fast approach and the shuffling, this is actually expected to be much smaller than that. """ -function brute_force_ordering(estimands; permutation_generator = estimands_permutation_generator(estimands), η_counts=nuisance_counts(estimands), do_shuffle=true, rng=Random.default_rng(), verbosity=0) - optimal_ordering = estimands +function brute_force_ordering(estimands; + permutation_generator = estimands_permutation_generator(estimands), + η_counts=nuisance_counts(estimands), + do_shuffle=true, + rng=Random.default_rng(), + verbosity=0 + ) estimands = do_shuffle ? shuffle(rng, estimands) : estimands + optimal_ordering = estimands min_maxmem_lowerbound = get_min_maxmem_lowerbound(estimands) optimal_maxmem, _ = evaluate_proxy_costs(estimands, η_counts) for perm ∈ permutation_generator diff --git a/test/estimand_ordering.jl b/test/estimand_ordering.jl index d1a924c4..49084266 100644 --- a/test/estimand_ordering.jl +++ b/test/estimand_ordering.jl @@ -83,6 +83,7 @@ statistical_estimands = [identify(x, scm) for x in causal_estimands] ordering_from_groups = groups_ordering(bad_ordering) @test TMLE.evaluate_proxy_costs(ordering_from_groups, η_counts) == (4, 9) # Adding a layer of brute forcing results in an optimal ordering + ordering_from_groups_with_brute_force = groups_ordering(bad_ordering, brute_force=true) @test TMLE.evaluate_proxy_costs(ordering_from_groups_with_brute_force, η_counts) == (3, 9) end From 32456132e377fe7e684489aa4cb2f98fbce8ebf2 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Sun, 19 Nov 2023 00:25:54 +0000 Subject: [PATCH 125/151] try add + to read --- ext/JSONExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/JSONExt.jl b/ext/JSONExt.jl index 2cd2fb33..6183f8bb 100644 --- a/ext/JSONExt.jl +++ b/ext/JSONExt.jl @@ -20,7 +20,7 @@ Writes a `Configuration` struct to a YAML file. The latter can be deserialized with `configuration_from_yaml`. """ function TMLE.configuration_to_json(file, config; indent=1) - open(file, "w") do io + open(file, "w+") do io write(io, JSON.json(to_dict(config), indent)) end end From bda3feceea7982cb4999331309b255c3fc739b1c Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 20 Nov 2023 10:07:55 +0000 Subject: [PATCH 126/151] change read method's name --- ext/JSONExt.jl | 10 ++++---- ext/YAMLExt.jl | 10 ++++---- src/TMLE.jl | 4 +-- src/configuration.jl | 8 +++--- test/configuration.jl | 25 +++++++++---------- .../non_regression_test.jl | 8 +++--- 6 files changed, 32 insertions(+), 33 deletions(-) diff --git a/ext/JSONExt.jl b/ext/JSONExt.jl index 6183f8bb..b50019bb 100644 --- a/ext/JSONExt.jl +++ b/ext/JSONExt.jl @@ -4,22 +4,22 @@ using JSON using TMLE """ - configuration_from_json(file) + read_json(file) Loads a YAML configuration file containing: - Estimands - An SCM (optional) - An Adjustment Method (optional) """ -TMLE.configuration_from_json(file) = from_dict!(JSON.parsefile(file, dicttype=Dict{Symbol, Any})) +TMLE.read_json(file) = from_dict!(JSON.parsefile(file, dicttype=Dict{Symbol, Any})) """ - configuration_to_json(file, config::Configuration) + write_json(file, config::Configuration) Writes a `Configuration` struct to a YAML file. The latter can be deserialized -with `configuration_from_yaml`. +with `read_yaml`. """ -function TMLE.configuration_to_json(file, config; indent=1) +function TMLE.write_json(file, config; indent=1) open(file, "w+") do io write(io, JSON.json(to_dict(config), indent)) end diff --git a/ext/YAMLExt.jl b/ext/YAMLExt.jl index d4d97922..86ab8a3e 100644 --- a/ext/YAMLExt.jl +++ b/ext/YAMLExt.jl @@ -4,21 +4,21 @@ using YAML using TMLE """ - configuration_from_yaml(file) + read_yaml(file) Loads a YAML configuration file containing: - Estimands - An SCM (optional) - An Adjustment Method (optional) """ -TMLE.configuration_from_yaml(file) = from_dict!(YAML.load_file(file, dicttype=Dict{Symbol, Any})) +TMLE.read_yaml(file) = from_dict!(YAML.load_file(file, dicttype=Dict{Symbol, Any})) """ - configuration_to_yaml(file, config::Configuration) + write_yaml(file, config::Configuration) Writes a `Configuration` struct to a YAML file. The latter can be deserialized -with `configuration_from_yaml`. +with `read_yaml`. """ -TMLE.configuration_to_yaml(file, config) = YAML.write_file(file, to_dict(config)) +TMLE.write_yaml(file, config) = YAML.write_file(file, to_dict(config)) end \ No newline at end of file diff --git a/src/TMLE.jl b/src/TMLE.jl index 6594db41..e4101e40 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -34,8 +34,8 @@ export compose export TreatmentTransformer, with_encoder, encoder export BackdoorAdjustment, identify export last_fluctuation_epsilon -export configuration_from_yaml, configuration_to_yaml -export configuration_from_json, configuration_to_json +export read_yaml, write_yaml +export read_json, write_json export to_dict, from_dict!, Configuration export brute_force_ordering, groups_ordering diff --git a/src/configuration.jl b/src/configuration.jl index dac88542..155c5c5c 100644 --- a/src/configuration.jl +++ b/src/configuration.jl @@ -39,7 +39,7 @@ function from_dict!(d::Dict{T, Any}) where T return constructor(;[key => from_dict!(val) for (key, val) in d]...) end -function configuration_from_yaml end -function configuration_to_yaml end -function configuration_from_json end -function configuration_to_json end \ No newline at end of file +function read_yaml end +function write_yaml end +function read_json end +function write_json end \ No newline at end of file diff --git a/test/configuration.jl b/test/configuration.jl index 4721d9f1..d76e514c 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -52,12 +52,12 @@ jsonfilename = string(outprefix, "json") treatment_confounders=(T3 = [:W], T1 = [:W]) ) ] - configuration_to_yaml(yamlfilename, estimands) - estimands_from_yaml = configuration_from_yaml(yamlfilename) + write_yaml(yamlfilename, estimands) + estimands_from_yaml = read_yaml(yamlfilename) check_estimands(estimands_from_yaml, estimands) rm(yamlfilename) - configuration_to_json(jsonfilename, estimands) - estimands_from_json = configuration_from_json(jsonfilename) + write_json(jsonfilename, estimands) + estimands_from_json = read_json(jsonfilename) check_estimands(estimands_from_json, estimands) rm(jsonfilename) @@ -65,10 +65,10 @@ jsonfilename = string(outprefix, "json") estimands=estimands ) config_dict = to_dict(configuration) - configuration_to_yaml(yamlfilename, configuration) - configuration_to_json(jsonfilename, configuration) - config_from_yaml = configuration_from_yaml(yamlfilename) - config_from_json = configuration_from_json(jsonfilename) + write_yaml(yamlfilename, configuration) + write_json(jsonfilename, configuration) + config_from_yaml = read_yaml(yamlfilename) + config_from_json = read_json(jsonfilename) for loaded_config in (config_from_yaml, config_from_json) @test loaded_config.scm === nothing @test loaded_config.adjustment === nothing @@ -77,7 +77,6 @@ jsonfilename = string(outprefix, "json") end rm(yamlfilename) rm(jsonfilename) - end @testset "Test with an SCM and causal estimands" begin @@ -111,11 +110,11 @@ end treatments = [:T1, :T2]), adjustment=BackdoorAdjustment([:C1, :C2]) ) - configuration_to_yaml(yamlfilename, configuration) - configuration_to_json(jsonfilename, configuration) + write_yaml(yamlfilename, configuration) + write_json(jsonfilename, configuration) - config_from_yaml = configuration_from_yaml(yamlfilename) - config_from_json = configuration_from_json(jsonfilename) + config_from_yaml = read_yaml(yamlfilename) + config_from_json = read_json(jsonfilename) for loaded_config in (config_from_yaml, config_from_json) scm = loaded_config.scm diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 84e9d7a7..63c0f31b 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -51,13 +51,13 @@ end tmle_result, cache = tmle(Ψ, dataset; verbosity=verbosity); regression_tests(tmle_result) if VERSION >= v"1.9" - configuration_to_json("result.json", [tmle_result]) - results_from_json = configuration_from_json("result.json") + write_json("result.json", [tmle_result]) + results_from_json = read_json("result.json") regression_tests(results_from_json[1]) rm("result.json") - configuration_to_yaml("result.yaml", [emptyIC(tmle_result)]) - results_from_yaml = configuration_from_yaml("result.yaml") + write_yaml("result.yaml", [emptyIC(tmle_result)]) + results_from_yaml = read_yaml("result.yaml") regression_tests(results_from_yaml[1]) rm("result.yaml") end From d02633e0a6eabb24cab13856eb81b2b7e5a2f66f Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 21 Nov 2023 15:34:00 +0000 Subject: [PATCH 127/151] add FitFailedError --- src/configuration.jl | 2 +- src/counterfactual_mean_based/estimators.jl | 45 +++++++++++++------ .../estimators_and_estimates.jl | 35 ++++++++++++++- .../non_regression_test.jl | 1 - 4 files changed, 67 insertions(+), 16 deletions(-) diff --git a/src/configuration.jl b/src/configuration.jl index 155c5c5c..09b1d95b 100644 --- a/src/configuration.jl +++ b/src/configuration.jl @@ -34,7 +34,7 @@ from_dict!(v::AbstractVector) = [from_dict!(x) for x in v] Converts a dictionary to a TMLE struct. """ function from_dict!(d::Dict{T, Any}) where T - haskey(d, T(:type)) || return d + haskey(d, T(:type)) || return Dict(key => from_dict!(val) for (key, val) in d) constructor = eval(Meta.parse(pop!(d, :type))) return constructor(;[key => from_dict!(val) for (key, val) in d]...) end diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 822048be..b4210769 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -1,6 +1,18 @@ ##################################################################### ### CMRelevantFactorsEstimator ### ##################################################################### +struct FitFailedError <: Exception + msg::String + origin::Exception +end + +propensity_score_fit_error_msg(factor) = string("Could not fit the following propensity score model: ", string_repr(factor)) +outcome_mean_fit_error_msg(factor) = string( + "Could not fit the following Outcome mean model: ", + string_repr(factor), + ".\n Hint: don't forget to use `with_encoder` to encode categorical variables.") + +Base.showerror(io::IO, e::FitFailedError) = print(io, e.msg) struct CMRelevantFactorsEstimator <: Estimator resampling::Union{Nothing, ResamplingStrategy} @@ -52,23 +64,30 @@ function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict() # Get train validation indices train_validation_indices = TMLE.get_train_validation_indices(estimator.resampling, estimand, dataset) # Fit propensity score - propensity_score_estimate = Tuple( - ConditionalDistributionEstimator(train_validation_indices, acquire_model(models, factor.outcome, dataset, true))( - factor, - dataset; - cache=cache, - verbosity=verbosity, - machine_cache=machine_cache - ) - for factor in estimand.propensity_score - ) + propensity_score_estimate = map(estimand.propensity_score) do factor + try + ConditionalDistributionEstimator(train_validation_indices, acquire_model(models, factor.outcome, dataset, true))( + factor, + dataset; + cache=cache, + verbosity=verbosity, + machine_cache=machine_cache + ) + catch e + throw(FitFailedError(propensity_score_fit_error_msg(factor), e)) + end + end # Fit outcome mean outcome_mean = estimand.outcome_mean model = acquire_model(models, outcome_mean.outcome, dataset, false) - outcome_mean_estimate = ConditionalDistributionEstimator( - train_validation_indices, - model + outcome_mean_estimate = try + ConditionalDistributionEstimator( + train_validation_indices, + model )(outcome_mean, dataset; cache=cache, verbosity=verbosity, machine_cache=machine_cache) + catch e + throw(FitFailedError(outcome_mean_fit_error_msg(outcome_mean), e)) + end # Build estimate estimate = MLCMRelevantFactors(estimand, outcome_mean_estimate, propensity_score_estimate) # Update cache diff --git a/test/counterfactual_mean_based/estimators_and_estimates.jl b/test/counterfactual_mean_based/estimators_and_estimates.jl index dcd80468..c8600e09 100644 --- a/test/counterfactual_mean_based/estimators_and_estimates.jl +++ b/test/counterfactual_mean_based/estimators_and_estimates.jl @@ -11,12 +11,17 @@ using CategoricalArrays using LogExpFunctions using MLJBase -@testset "Test CMRelevantFactorsEstimator" begin +function make_dataset() n = 100 W = rand(Normal(), n) T₁ = rand(n) .< logistic.(1 .- W) Y = T₁ .+ W .+ rand(n) dataset = DataFrame(Y=Y, W=W, T₁=categorical(T₁, ordered=true)) + return dataset +end + +@testset "Test CMRelevantFactorsEstimator" begin + dataset = make_dataset() # Estimand Q = TMLE.ConditionalDistribution(:Y, [:T₁, :W]) G = (TMLE.ConditionalDistribution(:T₁, [:W]),) @@ -74,6 +79,34 @@ using MLJBase @test η̂ₙ.outcome_mean.train_validation_indices == η̂ₙ.propensity_score[1].train_validation_indices end +@testset "Test FitFailedError" begin + dataset = make_dataset() + # Estimand + Q = TMLE.ConditionalDistribution(:Y, [:T₁, :W]) + G = (TMLE.ConditionalDistribution(:T₁, [:W]),) + η = TMLE.CMRelevantFactors(outcome_mean=Q, propensity_score=G) + # Estimator + models = ( + Y = with_encoder(LinearRegressor()), + T₁ = LinearRegressor() + ) + η̂ = TMLE.CMRelevantFactorsEstimator(models=models) + try + η̂(η, dataset; verbosity=0) + @test true === false + catch e + @test e isa TMLE.FitFailedError + @test e.msg == TMLE.propensity_score_fit_error_msg(G[1]) + end + + models = ( + Y = LogisticClassifier(), + T₁ = LogisticClassifier(fit_intercept=false) + ) + η̂ = TMLE.CMRelevantFactorsEstimator(models=models) + @test_throws TMLE.FitFailedError η̂(η, dataset; verbosity=0) +end + @testset "Test structs are concrete types" begin for type in (OSE, TMLEE, NAIVE) @test isconcretetype(type) diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 63c0f31b..240a8f4a 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -65,7 +65,6 @@ end naive = NAIVE(with_encoder(LinearBinaryClassifier())) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=verbosity) @test naive_result ≈ -0.150078 atol = 1e-6 - end end From a2cf0b96c4d1011534b5f1c91697823d7e87635e Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 21 Nov 2023 15:37:36 +0000 Subject: [PATCH 128/151] add estimand to FitFailed exception --- src/counterfactual_mean_based/estimators.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index b4210769..71c814df 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -2,6 +2,7 @@ ### CMRelevantFactorsEstimator ### ##################################################################### struct FitFailedError <: Exception + estimand::Estimand msg::String origin::Exception end @@ -74,7 +75,7 @@ function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict() machine_cache=machine_cache ) catch e - throw(FitFailedError(propensity_score_fit_error_msg(factor), e)) + throw(FitFailedError(factor, propensity_score_fit_error_msg(factor), e)) end end # Fit outcome mean @@ -86,7 +87,7 @@ function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict() model )(outcome_mean, dataset; cache=cache, verbosity=verbosity, machine_cache=machine_cache) catch e - throw(FitFailedError(outcome_mean_fit_error_msg(outcome_mean), e)) + throw(FitFailedError(outcome_mean, outcome_mean_fit_error_msg(outcome_mean), e)) end # Build estimate estimate = MLCMRelevantFactors(estimand, outcome_mean_estimate, propensity_score_estimate) From ed9f899f513d1bee8580cf05995a043d0a5bb1ce Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 22 Nov 2023 14:00:38 +0000 Subject: [PATCH 129/151] enforce sorted tuples to make estimands equal --- src/counterfactual_mean_based/estimands.jl | 16 ++++--- test/configuration.jl | 20 +++------ test/counterfactual_mean_based/estimands.jl | 49 +++++++++++++++++++-- 3 files changed, 61 insertions(+), 24 deletions(-) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index c2acae2e..cf53c876 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -144,13 +144,17 @@ treatment_specs_to_dict(treatment_values::NamedTuple) = Dict(pairs(treatment_val treatment_values(d::AbstractDict) = (;d...) treatment_values(d) = d -function treatment_specs_from_dict(dict::AbstractDict) - treatment_variables = keys(dict) - return NamedTuple{Tuple(treatment_variables)}([treatment_values(val) for val in values(dict)]) -end +get_treatment_specs(treatment_specs::NamedTuple{names, }) where names = + NamedTuple{sort(names)}(treatment_specs) -get_treatment_specs(treatment_values::NamedTuple) = treatment_values -get_treatment_specs(treatment_values::AbstractDict) = treatment_specs_from_dict(treatment_values) +function get_treatment_specs(treatment_specs::NamedTuple{names, <:Tuple{Vararg{NamedTuple}}}) where names + case_control = ((case=v[:case], control=v[:control]) for v in values(treatment_specs)) + treatment_specs = (;zip(keys(treatment_specs), case_control)...) + return NamedTuple{sort(names)}(treatment_specs) +end + +get_treatment_specs(treatment_specs::AbstractDict) = + get_treatment_specs((;(key => treatment_values(val) for (key, val) in treatment_specs)...)) constructorname(T; prefix="TMLE.Causal") = replace(string(T), prefix => "") diff --git a/test/configuration.jl b/test/configuration.jl index d76e514c..fbda0c4d 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -54,11 +54,11 @@ jsonfilename = string(outprefix, "json") ] write_yaml(yamlfilename, estimands) estimands_from_yaml = read_yaml(yamlfilename) - check_estimands(estimands_from_yaml, estimands) + @test estimands_from_yaml == estimands rm(yamlfilename) write_json(jsonfilename, estimands) estimands_from_json = read_json(jsonfilename) - check_estimands(estimands_from_json, estimands) + @test estimands_from_json == estimands rm(jsonfilename) configuration = Configuration( @@ -73,7 +73,7 @@ jsonfilename = string(outprefix, "json") @test loaded_config.scm === nothing @test loaded_config.adjustment === nothing reconstructed_estimands = loaded_config.estimands - check_estimands(reconstructed_estimands, estimands) + @test reconstructed_estimands == estimands end rm(yamlfilename) rm(jsonfilename) @@ -121,20 +121,12 @@ end adjustment = loaded_config.adjustment # Estimand 1 estimand₁ = loaded_config.estimands[1] + @test estimand₁ == estimands[1] statistical_estimand₁ = identify(adjustment, estimand₁, scm) - @test statistical_estimand₁.outcome == :Y1 - @test statistical_estimand₁.treatment_values.T1 == (case=1, control=0) - @test statistical_estimand₁.treatment_values.T2 == (case=1, control=0) - @test statistical_estimand₁.treatment_confounders.T2 == (:W,) - @test statistical_estimand₁.treatment_confounders.T1 == (:W,) - @test statistical_estimand₁.outcome_extra_covariates == (:C1, :C2) + @test identify(adjustment, estimand₁, scm) == identify(configuration.adjustment, estimands[1], configuration.scm) # Estimand 2 estimand₂ = loaded_config.estimands[2] - @test estimand₂.outcome == :Y3 - @test estimand₂.treatment_values.T1 == (case = 1, control = 0) - @test estimand₂.treatment_values.T3 == (case = "AC", control = "CC") - @test estimand₂.treatment_confounders.T1 == (:W,) - @test estimand₂.treatment_confounders.T3 == (:W,) + @test estimand₂ == estimands[2] end rm(yamlfilename) rm(jsonfilename) diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index 4392a7c1..0e2d6da0 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -80,6 +80,50 @@ using TMLE @test indic_values == [-1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0] end +@testset "Test standardization of StatisticalCMCompositeEstimand" begin + # ATE + Ψ₁ = ATE( + outcome="Y", + treatment_values = (T₁=(case=1, control=0), T₂=(control=1, case=0)), + treatment_confounders = ( + T₂ = [:W₂, "W₁"], + T₁ = ["W₀"] + ), + outcome_extra_covariates = ["Z", :A] + ) + Ψ₂ = ATE( + outcome=:Y, + treatment_values = (T₁=(case=1, control=0), T₂=(case=0, control=1)), + treatment_confounders = ( + T₁ = [:W₀], + T₂ = [:W₂, :W₁] + ), + outcome_extra_covariates = [:A, :Z] + ) + @test Ψ₁ == Ψ₂ + @test Ψ₁.outcome == :Y + @test Ψ₁.treatment_values == (T₁=(case=1, control=0), T₂=(case=0, control=1)) + @test Ψ₁.treatment_confounders == (T₁ = (:W₀,), T₂ = (:W₁, :W₂)) + @test Ψ₁.outcome_extra_covariates == (:A, :Z) + + # CM + Ψ₁ = CM( + outcome=:Y, + treatment_values=(T₁=1, T₂=0), + treatment_confounders = (T₂ = ["W₁", "W₂"], T₁ = [:W₀]), + ) + Ψ₂ = CM( + outcome=:Y, + treatment_values=(T₂=0, T₁=1), + treatment_confounders = (T₂ = [:W₂, "W₁"],T₁ = ["W₀"]), + ) + @test Ψ₁ == Ψ₂ + @test Ψ₁.outcome == :Y + @test Ψ₁.treatment_values == (T₁=1, T₂=0) + @test Ψ₁.treatment_confounders == (T₁ = (:W₀,), T₂ = (:W₁, :W₂)) + @test Ψ₁.outcome_extra_covariates == () +end + @testset "Test structs are concrete types" begin for type in Base.uniontypes(TMLE.StatisticalCMCompositeEstimand) @test isconcretetype(type) @@ -120,10 +164,7 @@ end :outcome => :y ) Ψreconstructed = from_dict!(d) - @test Ψreconstructed.outcome == Ψ.outcome - @test Ψreconstructed.treatment_values.T₁ == Ψ.treatment_values.T₁ - @test Ψreconstructed.treatment_values.T₂ == Ψ.treatment_values.T₂ - @test Ψreconstructed.treatment_values.T₃ == Ψ.treatment_values.T₃ + @test Ψ == Ψ # Statistical CM Ψ = CM( From 28e3753e2225c54423ee8680625942236d54584b Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 22 Nov 2023 16:09:46 +0000 Subject: [PATCH 130/151] add ComposedEstimand and associated estimator --- src/TMLE.jl | 1 + src/counterfactual_mean_based/estimates.jl | 113 --------------------- src/estimands.jl | 11 +- src/estimates.jl | 61 +++++++++++ src/estimators.jl | 85 ++++++++++++++++ test/composition.jl | 12 ++- 6 files changed, 166 insertions(+), 117 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index e4101e40..2ae0fec5 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -29,6 +29,7 @@ export SCM, StaticSCM, add_equations!, add_equation!, parents, vertices export CM, ATE, IATE export AVAILABLE_ESTIMANDS export TMLEE, OSE, NAIVE +export ComposedEstimand export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint, emptyIC export compose export TreatmentTransformer, with_encoder, encoder diff --git a/src/counterfactual_mean_based/estimates.jl b/src/counterfactual_mean_based/estimates.jl index 1c31f4b8..8acb26ed 100644 --- a/src/counterfactual_mean_based/estimates.jl +++ b/src/counterfactual_mean_based/estimates.jl @@ -32,7 +32,6 @@ struct TMLEstimate{T<:AbstractFloat} <: Estimate end TMLEstimate(;estimand, estimate::T, std::T, n, IC) where T = TMLEstimate(estimand, estimate, std, n, convert(Vector{T}, IC)) - struct OSEstimate{T<:AbstractFloat} <: Estimate estimand::StatisticalCMCompositeEstimand estimate::T @@ -65,23 +64,6 @@ function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) pretty_table(io, data;header=["Estimate", "95% Confidence Interval", "P-value"]) end -struct ComposedEstimate{T<:AbstractFloat} <: Estimate - estimate::Array{T} - σ̂::Matrix{T} - n::Int -end - -function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) - if length(est.σ̂) !== 1 - println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) - else - testresult = OneSampleTTest(est) - data = [estimate(est) confint(testresult) pvalue(testresult);] - headers = ["Estimate", "95% Confidence Interval", "P-value"] - pretty_table(io, data;header=headers) - end -end - """ Distributions.estimate(r::EICEstimate) @@ -89,14 +71,6 @@ Retrieves the final estimate: after the TMLE step. """ Distributions.estimate(est::EICEstimate) = est.estimate -""" - Distributions.estimate(r::ComposedEstimate) - -Retrieves the final estimate: after the TMLE step. -""" -Distributions.estimate(est::ComposedEstimate) = - length(est.estimate) == 1 ? est.estimate[1] : est.estimate - """ var(r::EICEstimate) @@ -105,14 +79,6 @@ Computes the estimated variance associated with the estimate. Statistics.var(est::EICEstimate) = var(est.IC)/size(est.IC, 1) -""" - var(r::ComposedEstimate) - -Computes the estimated variance associated with the estimate. -""" -Statistics.var(est::ComposedEstimate) = - length(est.σ̂) == 1 ? est.σ̂[1] / est.n : est.σ̂ ./ est.n - """ OneSampleZTest(r::EICEstimate, Ψ₀=0) @@ -130,82 +96,3 @@ Performs a T test on the EICEstimate. HypothesisTests.OneSampleTTest(est::EICEstimate, Ψ₀=0) = OneSampleTTest(est.estimate, est.std, est.n, Ψ₀) -""" - OneSampleTTest(r::ComposedEstimate, Ψ₀=0) - -Performs a T test on the ComposedEstimate. -""" -function HypothesisTests.OneSampleTTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleTTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) -end - -""" - OneSampleZTest(r::ComposedEstimate, Ψ₀=0) - -Performs a T test on the ComposedEstimate. -""" -function HypothesisTests.OneSampleZTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleZTest(estimate(est), sqrt(est.σ̂[1]), est.n, Ψ₀) -end - -""" - compose(f, estimation_results::Vararg{EICEstimate, N}) where N - -Provides an estimator of f(estimation_results...). - -# Mathematical details - -The following is a summary from `Asymptotic Statistics`, A. W. van der Vaart. - -Consider k TMLEs computed from a dataset of size n and embodied by Tₙ = (T₁,ₙ, ..., Tₖ,ₙ). -Since each of them is asymptotically normal, the multivariate CLT provides the joint -distribution: - - √n(Tₙ - Ψ₀) ↝ N(0, Σ), - -where Σ is the covariance matrix of the TMLEs influence curves. - -Let f:ℜᵏ→ℜᵐ, be a differentiable map at Ψ₀. Then, the delta method provides the -limiting distribution of √n(f(Tₙ) - f(Ψ₀)). Because Tₙ is normal, the result is: - - √n(f(Tₙ) - f(Ψ₀)) ↝ N(0, ∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ), - -where ∇f(Ψ₀):ℜᵏ→ℜᵐ is a linear map such that by abusing notations and identifying the -function with the multiplication matrix: ∇f(Ψ₀):h ↦ ∇f(Ψ₀) ̇h. And the matrix ∇f(Ψ₀) is -the jacobian of f at Ψ₀. - -Hence, the only thing we need to do is: -- Compute the covariance matrix Σ -- Compute the jacobian ∇f, which can be done using Julia's automatic differentiation facilities. -- The final estimator is normal with mean f₀=f(Ψ₀) and variance σ₀=∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ - -# Arguments - -- f: An array-input differentiable map. -- estimation_results: 1 or more `EICEstimate` structs. - -# Examples - -Assuming `res₁` and `res₂` are TMLEs: - -```julia -f(x, y) = [x^2 - y, y - 3x] -compose(f, res₁, res₂) -``` -""" -function compose(f, estimators::Vararg{EICEstimate, N}; backend=AD.ZygoteBackend()) where N - Σ = cov(estimators...) - estimates = [estimate(r) for r in estimators] - f₀, Js = AD.value_and_jacobian(backend, f, estimates...) - J = hcat(Js...) - n = size(first(estimators).IC, 1) - σ₀ = J*Σ*J' - return ComposedEstimate(collect(f₀), σ₀, n) -end - -function Statistics.cov(estimators::Vararg{EICEstimate, N}) where N - X = hcat([r.IC for r in estimators]...) - return Statistics.cov(X, dims=1, corrected=true) -end diff --git a/src/estimands.jl b/src/estimands.jl index 2bb4e05c..af3b4290 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -100,4 +100,13 @@ Defines an Expected Value estimand ``parents → E[Outcome|parents]``. At the moment there is no distinction between an Expected Value and a Conditional Distribution because they are estimated in the same way. """ -const ExpectedValue = ConditionalDistribution \ No newline at end of file +const ExpectedValue = ConditionalDistribution + +##################################################################### +### ComposedEstimand ### +##################################################################### + +struct ComposedEstimand <: Estimand + f::Function + args::Tuple +end \ No newline at end of file diff --git a/src/estimates.jl b/src/estimates.jl index f2b79f9a..7c071a1d 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -90,3 +90,64 @@ function compute_offset(estimate::ConditionalDistributionEstimate, X) ŷ = predict(estimate, X) return compute_offset(ŷ) end + +##################################################################### +### Composed Estimate ### +##################################################################### + +struct ComposedEstimate{T<:AbstractFloat} <: Estimate + estimand::ComposedEstimand + estimates::Tuple + estimate::Array{T} + std::Matrix{T} + n::Int +end + +function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) + if length(est.std) !== 1 + println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) + else + testresult = OneSampleTTest(est) + data = [estimate(est) confint(testresult) pvalue(testresult);] + headers = ["Estimate", "95% Confidence Interval", "P-value"] + pretty_table(io, data;header=headers) + end +end + +""" + Distributions.estimate(r::ComposedEstimate) + +Retrieves the final estimate: after the TMLE step. +""" +Distributions.estimate(est::ComposedEstimate) = + length(est.estimate) == 1 ? est.estimate[1] : est.estimate + + +""" + var(r::ComposedEstimate) + +Computes the estimated variance associated with the estimate. +""" +Statistics.var(est::ComposedEstimate) = + length(est.std) == 1 ? est.std[1] / est.n : est.std ./ est.n + + +""" + OneSampleTTest(r::ComposedEstimate, Ψ₀=0) + +Performs a T test on the ComposedEstimate. +""" +function HypothesisTests.OneSampleTTest(est::ComposedEstimate, Ψ₀=0) + @assert length(est.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleTTest(estimate(est), sqrt(est.std[1]), est.n, Ψ₀) +end + +""" + OneSampleZTest(r::ComposedEstimate, Ψ₀=0) + +Performs a T test on the ComposedEstimate. +""" +function HypothesisTests.OneSampleZTest(est::ComposedEstimate, Ψ₀=0) + @assert length(est.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleZTest(estimate(est), sqrt(est.std[1]), est.n, Ψ₀) +end diff --git a/src/estimators.jl b/src/estimators.jl index 933b7858..da8ddc31 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -85,3 +85,88 @@ ConditionalDistributionEstimator(train_validation_indices::Nothing, model) = ConditionalDistributionEstimator(train_validation_indices, model) = SampleSplitMLConditionalDistributionEstimator(model, train_validation_indices) + +##################################################################### +### ComposedEstimand Estimator ### +##################################################################### + +""" + (estimator::Estimator)(Ψ::ComposedEstimand, dataset; cache=Dict(), verbosity=1) + +Estimates all components of Ψ and then Ψ itself. +""" +function (estimator::Estimator)(Ψ::ComposedEstimand, dataset; cache=Dict(), verbosity=1, backend=AD.ZygoteBackend()) + estimates = map(Ψ.args) do estimand + estimate, cache = estimator(estimand, dataset; cache=cache, verbosity=verbosity) + estimate + end + f₀, σ₀, n = _compose(Ψ.f, estimates...; backend=backend) + return ComposedEstimate(Ψ, estimates, f₀, σ₀, n), cache +end + + +""" + compose(f, estimation_results::Vararg{EICEstimate, N}) where N + +Provides an estimator of f(estimation_results...). + +# Mathematical details + +The following is a summary from `Asymptotic Statistics`, A. W. van der Vaart. + +Consider k TMLEs computed from a dataset of size n and embodied by Tₙ = (T₁,ₙ, ..., Tₖ,ₙ). +Since each of them is asymptotically normal, the multivariate CLT provides the joint +distribution: + + √n(Tₙ - Ψ₀) ↝ N(0, Σ), + +where Σ is the covariance matrix of the TMLEs influence curves. + +Let f:ℜᵏ→ℜᵐ, be a differentiable map at Ψ₀. Then, the delta method provides the +limiting distribution of √n(f(Tₙ) - f(Ψ₀)). Because Tₙ is normal, the result is: + + √n(f(Tₙ) - f(Ψ₀)) ↝ N(0, ∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ), + +where ∇f(Ψ₀):ℜᵏ→ℜᵐ is a linear map such that by abusing notations and identifying the +function with the multiplication matrix: ∇f(Ψ₀):h ↦ ∇f(Ψ₀) ̇h. And the matrix ∇f(Ψ₀) is +the jacobian of f at Ψ₀. + +Hence, the only thing we need to do is: +- Compute the covariance matrix Σ +- Compute the jacobian ∇f, which can be done using Julia's automatic differentiation facilities. +- The final estimator is normal with mean f₀=f(Ψ₀) and variance σ₀=∇f(Ψ₀) ̇Σ ̇(∇f(Ψ₀))ᵀ + +# Arguments + +- f: An array-input differentiable map. +- estimation_results: 1 or more `EICEstimate` structs. + +# Examples + +Assuming `res₁` and `res₂` are TMLEs: + +```julia +f(x, y) = [x^2 - y, y - 3x] +compose(f, res₁, res₂) +``` +""" +function compose(f, estimates...; backend=AD.ZygoteBackend()) + f₀, σ₀, n = _compose(f, estimates...; backend=backend) + estimand = ComposedEstimand(f, Tuple(e.estimand for e in estimates)) + return ComposedEstimate(estimand, estimates, f₀, σ₀, n) +end + +function _compose(f, estimates...; backend=AD.ZygoteBackend()) + Σ = cov(estimates...) + point_estimates = [r.estimate for r in estimates] + f₀, Js = AD.value_and_jacobian(backend, f, point_estimates...) + J = hcat(Js...) + n = size(first(estimates).IC, 1) + σ₀ = J*Σ*J' + return collect(f₀), σ₀, n +end + +function Statistics.cov(estimates...) + X = hcat([r.IC for r in estimates]...) + return Statistics.cov(X, dims=1, corrected=true) +end \ No newline at end of file diff --git a/test/composition.jl b/test/composition.jl index 47ca69a7..67c6047f 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -60,13 +60,19 @@ end treatment_values = (T=0,), treatment_confounders = (T=[:W],) ) + # Estimate Individually CM_tmle_result₀, cache = tmle(CM₀, dataset; cache=cache, verbosity=0) CM_ose_result₀, cache = ose(CM₀, dataset; cache=cache, verbosity=0) - # Composition of TMLE - + # Compose estimates CM_result_composed_tmle = compose(-, CM_tmle_result₁, CM_tmle_result₀); CM_result_composed_ose = compose(-, CM_ose_result₁, CM_ose_result₀); - + # Estimate via ComposedEstimand + composed_estimand = ComposedEstimand(-, (CM₁, CM₀)) + composed_estimate, cache = tmle(composed_estimand, dataset; cache=cache, verbosity=0) + @test composed_estimate.estimand == CM_result_composed_tmle.estimand + @test CM_result_composed_tmle.estimate == composed_estimate.estimate + @test CM_result_composed_tmle.std == composed_estimate.std + @test CM_result_composed_tmle.n == composed_estimate.n # Via ATE ATE₁₀ = ATE( outcome = :Y, From 6ed3afabaa4eac1ce1a3587ec0952b754a98ac4b Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 23 Nov 2023 11:03:22 +0000 Subject: [PATCH 131/151] remove exports of some functions --- ext/JSONExt.jl | 4 +- ext/YAMLExt.jl | 4 +- src/TMLE.jl | 4 +- test/configuration.jl | 42 +++++++++---------- test/counterfactual_mean_based/adjustment.jl | 6 +-- test/counterfactual_mean_based/estimands.jl | 16 +++---- .../non_regression_test.jl | 12 +++--- test/scm.jl | 6 +-- 8 files changed, 44 insertions(+), 50 deletions(-) diff --git a/ext/JSONExt.jl b/ext/JSONExt.jl index b50019bb..d1f49082 100644 --- a/ext/JSONExt.jl +++ b/ext/JSONExt.jl @@ -11,7 +11,7 @@ Loads a YAML configuration file containing: - An SCM (optional) - An Adjustment Method (optional) """ -TMLE.read_json(file) = from_dict!(JSON.parsefile(file, dicttype=Dict{Symbol, Any})) +TMLE.read_json(file) = TMLE.from_dict!(JSON.parsefile(file, dicttype=Dict{Symbol, Any})) """ write_json(file, config::Configuration) @@ -21,7 +21,7 @@ with `read_yaml`. """ function TMLE.write_json(file, config; indent=1) open(file, "w+") do io - write(io, JSON.json(to_dict(config), indent)) + write(io, JSON.json(TMLE.to_dict(config), indent)) end end diff --git a/ext/YAMLExt.jl b/ext/YAMLExt.jl index 86ab8a3e..1a06db27 100644 --- a/ext/YAMLExt.jl +++ b/ext/YAMLExt.jl @@ -11,7 +11,7 @@ Loads a YAML configuration file containing: - An SCM (optional) - An Adjustment Method (optional) """ -TMLE.read_yaml(file) = from_dict!(YAML.load_file(file, dicttype=Dict{Symbol, Any})) +TMLE.read_yaml(file) = TMLE.from_dict!(YAML.load_file(file, dicttype=Dict{Symbol, Any})) """ write_yaml(file, config::Configuration) @@ -19,6 +19,6 @@ TMLE.read_yaml(file) = from_dict!(YAML.load_file(file, dicttype=Dict{Symbol, Any Writes a `Configuration` struct to a YAML file. The latter can be deserialized with `read_yaml`. """ -TMLE.write_yaml(file, config) = YAML.write_file(file, to_dict(config)) +TMLE.write_yaml(file, config) = YAML.write_file(file, TMLE.to_dict(config)) end \ No newline at end of file diff --git a/src/TMLE.jl b/src/TMLE.jl index 2ae0fec5..5630429c 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -35,9 +35,7 @@ export compose export TreatmentTransformer, with_encoder, encoder export BackdoorAdjustment, identify export last_fluctuation_epsilon -export read_yaml, write_yaml -export read_json, write_json -export to_dict, from_dict!, Configuration +export Configuration export brute_force_ordering, groups_ordering # ############################################################################# diff --git a/test/configuration.jl b/test/configuration.jl index fbda0c4d..7757bfad 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -17,11 +17,9 @@ function check_estimands(reconstructed_estimands, estimands) end end -outprefix = "test_serialized." -yamlfilename = string(outprefix, "yaml") -jsonfilename = string(outprefix, "json") - @testset "Test Configurations" begin + yamlfilename = mktemp()[1] + jsonfilename = mktemp()[1] estimands = [ IATE( outcome=:Y1, @@ -52,34 +50,34 @@ jsonfilename = string(outprefix, "json") treatment_confounders=(T3 = [:W], T1 = [:W]) ) ] - write_yaml(yamlfilename, estimands) - estimands_from_yaml = read_yaml(yamlfilename) + TMLE.write_yaml(yamlfilename, estimands) + estimands_from_yaml = TMLE.read_yaml(yamlfilename) @test estimands_from_yaml == estimands - rm(yamlfilename) - write_json(jsonfilename, estimands) - estimands_from_json = read_json(jsonfilename) + + TMLE.write_json(jsonfilename, estimands) + estimands_from_json = TMLE.read_json(jsonfilename) @test estimands_from_json == estimands - rm(jsonfilename) configuration = Configuration( estimands=estimands ) - config_dict = to_dict(configuration) - write_yaml(yamlfilename, configuration) - write_json(jsonfilename, configuration) - config_from_yaml = read_yaml(yamlfilename) - config_from_json = read_json(jsonfilename) + config_dict = TMLE.to_dict(configuration) + TMLE.write_yaml(yamlfilename, configuration) + TMLE.write_json(jsonfilename, configuration) + config_from_yaml = TMLE.read_yaml(yamlfilename) + config_from_json = TMLE.read_json(jsonfilename) for loaded_config in (config_from_yaml, config_from_json) @test loaded_config.scm === nothing @test loaded_config.adjustment === nothing reconstructed_estimands = loaded_config.estimands @test reconstructed_estimands == estimands end - rm(yamlfilename) - rm(jsonfilename) + end @testset "Test with an SCM and causal estimands" begin + yamlfilename = mktemp()[1] + jsonfilename = mktemp()[1] # With a StaticSCM, some Causal estimands and an Adjustment Method estimands = [ IATE( @@ -110,11 +108,11 @@ end treatments = [:T1, :T2]), adjustment=BackdoorAdjustment([:C1, :C2]) ) - write_yaml(yamlfilename, configuration) - write_json(jsonfilename, configuration) + TMLE.write_yaml(yamlfilename, configuration) + TMLE.write_json(jsonfilename, configuration) - config_from_yaml = read_yaml(yamlfilename) - config_from_json = read_json(jsonfilename) + config_from_yaml = TMLE.read_yaml(yamlfilename) + config_from_json = TMLE.read_json(jsonfilename) for loaded_config in (config_from_yaml, config_from_json) scm = loaded_config.scm @@ -128,8 +126,6 @@ end estimand₂ = loaded_config.estimands[2] @test estimand₂ == estimands[2] end - rm(yamlfilename) - rm(jsonfilename) end end diff --git a/test/counterfactual_mean_based/adjustment.jl b/test/counterfactual_mean_based/adjustment.jl index 29638d98..04dac9cd 100644 --- a/test/counterfactual_mean_based/adjustment.jl +++ b/test/counterfactual_mean_based/adjustment.jl @@ -29,14 +29,14 @@ using TMLE @test statistical_estimands[4].treatment_confounders == (T₁=(:W₁, :W₂), T₂=(:W₁, :W₂)) end -@testset "Test to_dict" begin +@testset "Test TMLE.to_dict" begin adjustment = BackdoorAdjustment(outcome_extra_covariates=[:C]) - adjustment_dict = to_dict(adjustment) + adjustment_dict = TMLE.to_dict(adjustment) @test adjustment_dict == Dict( :outcome_extra_covariates => [:C], :type => "BackdoorAdjustment" ) - @test from_dict!(adjustment_dict) == adjustment + @test TMLE.from_dict!(adjustment_dict) == adjustment end end diff --git a/test/counterfactual_mean_based/estimands.jl b/test/counterfactual_mean_based/estimands.jl index 0e2d6da0..7f4aca81 100644 --- a/test/counterfactual_mean_based/estimands.jl +++ b/test/counterfactual_mean_based/estimands.jl @@ -136,13 +136,13 @@ end outcome=:y, treatment_values = (T₁=1, T₂="AC") ) - d = to_dict(Ψ) + d = TMLE.to_dict(Ψ) @test d == Dict( :type => "CM", :treatment_values => Dict(:T₁=>1, :T₂=>"AC"), :outcome => :y ) - Ψreconstructed = from_dict!(d) + Ψreconstructed = TMLE.from_dict!(d) @test Ψreconstructed == Ψ # Causal ATE Ψ = ATE( @@ -153,7 +153,7 @@ end T₃=(case="C", control="D") ), ) - d = to_dict(Ψ) + d = TMLE.to_dict(Ψ) @test d == Dict( :type => "ATE", :treatment_values => Dict( @@ -163,7 +163,7 @@ end ), :outcome => :y ) - Ψreconstructed = from_dict!(d) + Ψreconstructed = TMLE.from_dict!(d) @test Ψ == Ψ # Statistical CM @@ -172,7 +172,7 @@ end treatment_values = (T₁=1, T₂="AC"), treatment_confounders = (T₁=[:W₁₁, :W₁₂], T₂=[:W₁₂, :W₂₂]) ) - d = to_dict(Ψ) + d = TMLE.to_dict(Ψ) @test d == Dict( :outcome_extra_covariates => [], :type => "CM", @@ -180,7 +180,7 @@ end :outcome => :y, :treatment_confounders => Dict(:T₁=>[:W₁₁, :W₁₂], :T₂=>[:W₁₂, :W₂₂]) ) - Ψreconstructed = from_dict!(d) + Ψreconstructed = TMLE.from_dict!(d) @test Ψreconstructed == Ψ # Statistical IATE @@ -190,7 +190,7 @@ end treatment_confounders = (T₁=[:W₁₁, :W₁₂], T₂=[:W₁₂, :W₂₂]), outcome_extra_covariates=[:C] ) - d = to_dict(Ψ) + d = TMLE.to_dict(Ψ) @test d == Dict( :outcome_extra_covariates => [:C], :type => "IATE", @@ -198,7 +198,7 @@ end :outcome => :y, :treatment_confounders => Dict(:T₁=>[:W₁₁, :W₁₂], :T₂=>[:W₁₂, :W₂₂]) ) - Ψreconstructed = from_dict!(d) + Ψreconstructed = TMLE.from_dict!(d) @test Ψreconstructed == Ψ end end diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 240a8f4a..646e5476 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -51,15 +51,15 @@ end tmle_result, cache = tmle(Ψ, dataset; verbosity=verbosity); regression_tests(tmle_result) if VERSION >= v"1.9" - write_json("result.json", [tmle_result]) - results_from_json = read_json("result.json") + jsonfile = mktemp()[1] + TMLE.write_json(jsonfile, [tmle_result]) + results_from_json = TMLE.read_json(jsonfile) regression_tests(results_from_json[1]) - rm("result.json") - write_yaml("result.yaml", [emptyIC(tmle_result)]) - results_from_yaml = read_yaml("result.yaml") + yamlfile = mktemp()[1] + TMLE.write_yaml(yamlfile, [emptyIC(tmle_result)]) + results_from_yaml = TMLE.read_yaml(yamlfile) regression_tests(results_from_yaml[1]) - rm("result.yaml") end # Naive naive = NAIVE(with_encoder(LinearBinaryClassifier())) diff --git a/test/scm.jl b/test/scm.jl index 65dfc333..1a424916 100644 --- a/test/scm.jl +++ b/test/scm.jl @@ -19,7 +19,7 @@ using TMLE expected_vertices = Set([:Y, :T₁, :T₂, :W₁, :W₂, :C]) @test Set(vertices(scm)) == expected_vertices - scm_dict = to_dict(scm) + scm_dict = TMLE.to_dict(scm) @test scm_dict == Dict( :type => "SCM", :equations => [ @@ -31,7 +31,7 @@ using TMLE Dict(:parents => [:W₂], :outcome => :T₂) ] ) - reconstructed_scm = from_dict!(scm_dict) + reconstructed_scm = TMLE.from_dict!(scm_dict) for vertexlabel ∈ expected_vertices @test parents(reconstructed_scm, vertexlabel) == parents(scm, vertexlabel) end @@ -61,7 +61,7 @@ end @test parents(scm, :Y₁) == parents(scm, :Y₂) == Set([:T₁, :T₂, :T₃, :W₁, :W₂]) @test parents(scm, :T₁) == parents(scm, :T₂) == parents(scm, :T₃) == Set([:W₁, :W₂]) - reconstructed_scm = from_dict!(Dict( + reconstructed_scm = TMLE.from_dict!(Dict( :type => "StaticSCM", :outcomes => [:Y₁, :Y₂], :treatments => [:T₁, :T₂, :T₃], From b49b6b4777eb5b9a59de2eb7d855825e5bd1b3b7 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 23 Nov 2023 11:59:47 +0000 Subject: [PATCH 132/151] add serialization of composite struct --- src/estimands.jl | 12 +++++++++++- test/composition.jl | 16 ++++++++++++++++ test/configuration.jl | 22 ++++++++++++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) diff --git a/src/estimands.jl b/src/estimands.jl index af3b4290..3b3d0589 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -109,4 +109,14 @@ const ExpectedValue = ConditionalDistribution struct ComposedEstimand <: Estimand f::Function args::Tuple -end \ No newline at end of file +end + +ComposedEstimand(f::String, args::AbstractVector) = ComposedEstimand(eval(Meta.parse(f)), Tuple(args)) + +ComposedEstimand(;f, args) = ComposedEstimand(f, args) + +to_dict(Ψ::ComposedEstimand) = Dict( + :type => string(ComposedEstimand), + :f => string(nameof(Ψ.f)), + :args => [to_dict(x) for x in Ψ.args] +) \ No newline at end of file diff --git a/test/composition.jl b/test/composition.jl index 67c6047f..a2581974 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -36,6 +36,22 @@ end @test Σ == cov(X) end +@testset "Test to_dict and from_dict!" begin + ATE₁ = ATE( + outcome=:Y, + treatment_values = (T=(case=1, control=0),), + treatment_confounders = (T=[:W],) + ) + ATE₂ = ATE( + outcome=:Y, + treatment_values = (T=(case=2, control=1),), + treatment_confounders = (T=[:W],) + ) + diff = ComposedEstimand(-, (ATE₁, ATE₂)) + d = TMLE.to_dict(diff) + diff_from_dict = TMLE.from_dict!(d) + @test diff_from_dict == diff +end @testset "Test composition CM(1) - CM(0) = ATE(1,0)" begin dataset = make_dataset(;n=1000) # Counterfactual Mean T = 1 diff --git a/test/configuration.jl b/test/configuration.jl index 7757bfad..bf1d9e7f 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -4,6 +4,7 @@ using Test using TMLE using YAML using JSON +using Serialization function check_estimands(reconstructed_estimands, estimands) for (index, estimand) in enumerate(estimands) @@ -128,6 +129,27 @@ end end end +@testset "Test serialization" begin + ATE₁ = ATE( + outcome=:Y, + treatment_values = (T=(case=1, control=0),), + treatment_confounders = (T=[:W],) + ) + ATE₂ = ATE( + outcome=:Y, + treatment_values = (T=(case=2, control=1),), + treatment_confounders = (T=[:W],) + ) + diff = ComposedEstimand(-, (ATE₁, ATE₂)) + jlsfile = mktemp()[1] + serialize(jlsfile, estimands) + estimands_from_jls = deserialize(jlsfile) + @test estimands_from_jls[1] == estimands[1] + @test estimands_from_jls[2] == estimands[2] + @test estimands_from_jls[3] == estimands[3] +end + + end true \ No newline at end of file From 47873b85fb9446cfca3d4d7568855b727834c5e9 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 23 Nov 2023 16:36:14 +0000 Subject: [PATCH 133/151] fix serialization test --- test/configuration.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/configuration.jl b/test/configuration.jl index bf1d9e7f..01e43ae9 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -141,6 +141,7 @@ end treatment_confounders = (T=[:W],) ) diff = ComposedEstimand(-, (ATE₁, ATE₂)) + estimands = [ATE₁, ATE₂, diff] jlsfile = mktemp()[1] serialize(jlsfile, estimands) estimands_from_jls = deserialize(jlsfile) From e5b52daf687c9c19c59313ab0365fd8ad6099f1d Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Thu, 23 Nov 2023 17:57:54 +0000 Subject: [PATCH 134/151] change nuisance function counts function --- src/counterfactual_mean_based/estimands.jl | 10 ++++++++-- src/estimand_ordering.jl | 14 ++++++-------- src/estimands.jl | 5 ++++- test/estimand_ordering.jl | 2 +- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index cf53c876..39da66f5 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -119,12 +119,18 @@ function indicator_fns(Ψ::StatisticalIATE) return Dict(key_vals...) end +expected_value(Ψ::StatisticalCMCompositeEstimand) = ExpectedValue(Ψ.outcome, Tuple(union(Ψ.outcome_extra_covariates, keys(Ψ.treatment_confounders), (Ψ.treatment_confounders)...))) +propensity_score(Ψ::StatisticalCMCompositeEstimand) = Tuple(ConditionalDistribution(T, Ψ.treatment_confounders[T]) for T in treatments(Ψ)) + function get_relevant_factors(Ψ::StatisticalCMCompositeEstimand) - outcome_model = ExpectedValue(Ψ.outcome, Tuple(union(Ψ.outcome_extra_covariates, keys(Ψ.treatment_confounders), (Ψ.treatment_confounders)...))) - treatment_factors = Tuple(ConditionalDistribution(T, Ψ.treatment_confounders[T]) for T in treatments(Ψ)) + outcome_model = expected_value(Ψ) + treatment_factors = propensity_score(Ψ) return CMRelevantFactors(outcome_model, treatment_factors) end +nuisance_functions_iterator(Ψ::StatisticalCMCompositeEstimand) = + (propensity_score(Ψ)..., expected_value(Ψ)) + function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: StatisticalCMCompositeEstimand param_string = string( Base.typename(T).wrapper, diff --git a/src/estimand_ordering.jl b/src/estimand_ordering.jl index a243243e..65158abf 100644 --- a/src/estimand_ordering.jl +++ b/src/estimand_ordering.jl @@ -1,4 +1,4 @@ -function maybe_update_counts!(dict, key) +function update_counts!(dict, key) if haskey(dict, key) dict[key] += 1 else @@ -6,14 +6,12 @@ function maybe_update_counts!(dict, key) end end -function nuisance_counts(estimands) +function nuisance_function_counts(estimands) η_counts = Dict() for Ψ in estimands - η = TMLE.get_relevant_factors(Ψ) - for ps in η.propensity_score - maybe_update_counts!(η_counts, ps) + for η in nuisance_functions_iterator(Ψ) + update_counts!(η_counts, η) end - maybe_update_counts!(η_counts, η.outcome_mean) end return η_counts end @@ -101,7 +99,7 @@ function propensity_score_group_based_permutation_generator(estimands, estimands end """ - brute_force_ordering(estimands; η_counts = nuisance_counts(estimands)) + brute_force_ordering(estimands; η_counts = nuisance_function_counts(estimands)) Finds an optimal ordering of the estimands to minimize maximum cache size. The approach is a brute force one, all permutations are generated and evaluated, @@ -111,7 +109,7 @@ the shuffling, this is actually expected to be much smaller than that. """ function brute_force_ordering(estimands; permutation_generator = estimands_permutation_generator(estimands), - η_counts=nuisance_counts(estimands), + η_counts=nuisance_function_counts(estimands), do_shuffle=true, rng=Random.default_rng(), verbosity=0 diff --git a/src/estimands.jl b/src/estimands.jl index 3b3d0589..f2cb0439 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -119,4 +119,7 @@ to_dict(Ψ::ComposedEstimand) = Dict( :type => string(ComposedEstimand), :f => string(nameof(Ψ.f)), :args => [to_dict(x) for x in Ψ.args] -) \ No newline at end of file +) + +nuisance_functions_iterator(Ψ::ComposedEstimand) = + Iterators.flatten(nuisance_functions_iterator(arg) for arg in Ψ.args) diff --git a/test/estimand_ordering.jl b/test/estimand_ordering.jl index 49084266..e18ff7e3 100644 --- a/test/estimand_ordering.jl +++ b/test/estimand_ordering.jl @@ -58,7 +58,7 @@ statistical_estimands = [identify(x, scm) for x in causal_estimands] # 7 || (T₃, Y₂|T₃) # 8 || (T₁, T₃, Y₂|T₁,T₃) # ---------------------------------- - η_counts = TMLE.nuisance_counts(statistical_estimands) + η_counts = TMLE.nuisance_function_counts(statistical_estimands) @test η_counts == Dict( TMLE.ConditionalDistribution(:Y₂, (:T₁, :W₁, :W₂)) => 1, TMLE.ConditionalDistribution(:Y₂, (:T₁, :T₃, :W₁, :W₂, :W₃)) => 1, From 320f4286c268ad93e1619ad46c266cc11fe5d737 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 24 Nov 2023 16:47:14 +0000 Subject: [PATCH 135/151] add ordering support for composed estimands --- src/TMLE.jl | 2 +- src/adjustment.jl | 28 +++++++ src/counterfactual_mean_based/adjustment.jl | 46 ---------- src/counterfactual_mean_based/estimands.jl | 33 +++++++- src/estimand_ordering.jl | 84 +++++++++---------- src/estimands.jl | 8 ++ .../adjustment.jl | 0 test/estimand_ordering.jl | 48 ++++++++++- test/estimands.jl | 18 ++++ test/runtests.jl | 4 +- 10 files changed, 175 insertions(+), 96 deletions(-) create mode 100644 src/adjustment.jl delete mode 100644 src/counterfactual_mean_based/adjustment.jl rename test/{counterfactual_mean_based => }/adjustment.jl (100%) diff --git a/src/TMLE.jl b/src/TMLE.jl index 5630429c..303af44f 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -44,6 +44,7 @@ export brute_force_ordering, groups_ordering include("utils.jl") include("scm.jl") +include("adjustment.jl") include("estimands.jl") include("estimators.jl") include("estimates.jl") @@ -56,7 +57,6 @@ include("counterfactual_mean_based/fluctuation.jl") include("counterfactual_mean_based/estimators.jl") include("counterfactual_mean_based/clever_covariate.jl") include("counterfactual_mean_based/gradient.jl") -include("counterfactual_mean_based/adjustment.jl") include("configuration.jl") diff --git a/src/adjustment.jl b/src/adjustment.jl new file mode 100644 index 00000000..dbc8f412 --- /dev/null +++ b/src/adjustment.jl @@ -0,0 +1,28 @@ +##################################################################### +### Identification Methods ### +##################################################################### + +abstract type AdjustmentMethod end + +struct BackdoorAdjustment <: AdjustmentMethod + outcome_extra_covariates::Tuple{Vararg{Symbol}} + BackdoorAdjustment(outcome_extra_covariates) = new(unique_sorted_tuple(outcome_extra_covariates)) +end + +BackdoorAdjustment(;outcome_extra_covariates=()) = + BackdoorAdjustment(outcome_extra_covariates) + +function statistical_type_from_causal_type(T) + typestring = string(Base.typename(T).wrapper) + new_typestring = replace(typestring, "TMLE.Causal" => "") + return eval(Symbol(new_typestring)) +end + +identify(estimand, scm::SCM; method=BackdoorAdjustment()::BackdoorAdjustment) = + identify(method, estimand, scm) + + +to_dict(adjustment::BackdoorAdjustment) = Dict( + :type => "BackdoorAdjustment", + :outcome_extra_covariates => collect(adjustment.outcome_extra_covariates) +) \ No newline at end of file diff --git a/src/counterfactual_mean_based/adjustment.jl b/src/counterfactual_mean_based/adjustment.jl deleted file mode 100644 index d8a42c43..00000000 --- a/src/counterfactual_mean_based/adjustment.jl +++ /dev/null @@ -1,46 +0,0 @@ -##################################################################### -### Identification Methods ### -##################################################################### - -abstract type AdjustmentMethod end - -struct BackdoorAdjustment <: AdjustmentMethod - outcome_extra_covariates::Tuple{Vararg{Symbol}} - BackdoorAdjustment(outcome_extra_covariates) = new(unique_sorted_tuple(outcome_extra_covariates)) -end - -BackdoorAdjustment(;outcome_extra_covariates=()) = - BackdoorAdjustment(outcome_extra_covariates) - -function statistical_type_from_causal_type(T) - typestring = string(Base.typename(T).wrapper) - new_typestring = replace(typestring, "TMLE.Causal" => "") - return eval(Symbol(new_typestring)) -end - -identify(causal_estimand::T, scm::SCM; method=BackdoorAdjustment()::BackdoorAdjustment) where T<:CausalCMCompositeEstimands = - identify(method, causal_estimand, scm) - - -function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) where T<:CausalCMCompositeEstimands - # Treatment confounders - treatment_names = keys(causal_estimand.treatment_values) - treatment_codes = [code_for(scm.graph, treatment) for treatment ∈ treatment_names] - confounders_codes = scm.graph.graph.badjlist[treatment_codes] - treatment_confounders = NamedTuple{treatment_names}( - [[scm.graph.vertex_labels[w] for w in confounders_codes[i]] - for i in eachindex(confounders_codes)] - ) - - return statistical_type_from_causal_type(T)(; - outcome=causal_estimand.outcome, - treatment_values = causal_estimand.treatment_values, - treatment_confounders = treatment_confounders, - outcome_extra_covariates = method.outcome_extra_covariates - ) -end - -to_dict(adjustment::BackdoorAdjustment) = Dict( - :type => "BackdoorAdjustment", - :outcome_extra_covariates => collect(adjustment.outcome_extra_covariates) -) \ No newline at end of file diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 39da66f5..270f623d 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -119,17 +119,24 @@ function indicator_fns(Ψ::StatisticalIATE) return Dict(key_vals...) end -expected_value(Ψ::StatisticalCMCompositeEstimand) = ExpectedValue(Ψ.outcome, Tuple(union(Ψ.outcome_extra_covariates, keys(Ψ.treatment_confounders), (Ψ.treatment_confounders)...))) +outcome_mean(Ψ::StatisticalCMCompositeEstimand) = ExpectedValue(Ψ.outcome, Tuple(union(Ψ.outcome_extra_covariates, keys(Ψ.treatment_confounders), (Ψ.treatment_confounders)...))) + +outcome_mean_key(Ψ::StatisticalCMCompositeEstimand) = variables(outcome_mean(Ψ)) + propensity_score(Ψ::StatisticalCMCompositeEstimand) = Tuple(ConditionalDistribution(T, Ψ.treatment_confounders[T]) for T in treatments(Ψ)) +propensity_score_key(Ψ::StatisticalCMCompositeEstimand) = Tuple(variables(x) for x ∈ propensity_score(Ψ)) + function get_relevant_factors(Ψ::StatisticalCMCompositeEstimand) - outcome_model = expected_value(Ψ) + outcome_model = outcome_mean(Ψ) treatment_factors = propensity_score(Ψ) return CMRelevantFactors(outcome_model, treatment_factors) end +n_uniques_nuisance_functions(Ψ::StatisticalCMCompositeEstimand) = length(propensity_score(Ψ)) + 1 + nuisance_functions_iterator(Ψ::StatisticalCMCompositeEstimand) = - (propensity_score(Ψ)..., expected_value(Ψ)) + (propensity_score(Ψ)..., outcome_mean(Ψ)) function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: StatisticalCMCompositeEstimand param_string = string( @@ -193,4 +200,24 @@ function to_dict(Ψ::T) where T <: StatisticalCMCompositeEstimand :treatment_confounders => treatment_confounders_to_dict(Ψ.treatment_confounders), :outcome_extra_covariates => collect(Ψ.outcome_extra_covariates) ) +end + +identify(method::AdjustmentMethod, Ψ::StatisticalCMCompositeEstimand, scm::SCM) = Ψ + +function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) where T<:CausalCMCompositeEstimands + # Treatment confounders + treatment_names = keys(causal_estimand.treatment_values) + treatment_codes = [code_for(scm.graph, treatment) for treatment ∈ treatment_names] + confounders_codes = scm.graph.graph.badjlist[treatment_codes] + treatment_confounders = NamedTuple{treatment_names}( + [[scm.graph.vertex_labels[w] for w in confounders_codes[i]] + for i in eachindex(confounders_codes)] + ) + + return statistical_type_from_causal_type(T)(; + outcome=causal_estimand.outcome, + treatment_values = causal_estimand.treatment_values, + treatment_confounders = treatment_confounders, + outcome_extra_covariates = method.outcome_extra_covariates + ) end \ No newline at end of file diff --git a/src/estimand_ordering.jl b/src/estimand_ordering.jl index 65158abf..21db0dc6 100644 --- a/src/estimand_ordering.jl +++ b/src/estimand_ordering.jl @@ -33,23 +33,22 @@ function evaluate_proxy_costs(estimands, η_counts; verbosity=0) maxmem = 0 compcost = 0 for (estimand_index, Ψ) ∈ enumerate(estimands) - η = get_relevant_factors(Ψ) - models = (η.propensity_score..., η.outcome_mean) # Append cache - for model in models - if model ∉ cache + ηs = collect(nuisance_functions_iterator(Ψ)) + for η ∈ ηs + if η ∉ cache compcost += 1 - push!(cache, model) + push!(cache, η) end end # Update maxmem maxmem = max(maxmem, length(cache)) verbosity > 0 && @info string("Cache size after estimand $estimand_index: ", length(cache)) # Free cache from models that are not useful anymore - for model in models - η_counts[model] -= 1 - if η_counts[model] <= 0 - pop!(cache, model) + for η in ηs + η_counts[η] -= 1 + if η_counts[η] <= 0 + pop!(cache, η) end end end @@ -66,8 +65,7 @@ on the cache size. It can be computed in a single pass, i.e. in O(N). function get_min_maxmem_lowerbound(estimands) min_maxmem_lowerbound = 0 for Ψ in estimands - η = get_relevant_factors(Ψ) - candidate_min = length((η.propensity_score..., η.outcome_mean)) + candidate_min = n_uniques_nuisance_functions(Ψ) if candidate_min > min_maxmem_lowerbound min_maxmem_lowerbound = candidate_min end @@ -77,25 +75,9 @@ end estimands_permutation_generator(estimands) = Combinatorics.permutations(estimands) -function get_propensity_score_groups(estimands_and_nuisances) - ps_groups = [[]] - current_ps = estimands_and_nuisances[1][2] - for (index, Ψ_and_ηs) ∈ enumerate(estimands_and_nuisances) - new_ps = Ψ_and_ηs[2] - if new_ps == current_ps - push!(ps_groups[end], index) - else - current_ps = new_ps - push!(ps_groups, [index]) - end - end - return ps_groups -end - -function propensity_score_group_based_permutation_generator(estimands, estimands_and_nuisances) - ps_groups = get_propensity_score_groups(estimands_and_nuisances) - group_permutations = Combinatorics.permutations(collect(1:length(ps_groups))) - return (vcat((estimands[ps_groups[index]] for index in group_perm)...) for group_perm in group_permutations) +function propensity_score_group_based_permutation_generator(groups) + group_permutations = Combinatorics.permutations(collect(keys(groups))) + return (vcat((groups[ps_key] for ps_key in group_perm)...) for group_perm in group_permutations) end """ @@ -133,6 +115,32 @@ function brute_force_ordering(estimands; return optimal_ordering end +""" + groupby_by_propensity_score(estimands) + +Group parameters per propensity score and order each group by outcome_mean. +""" +function groupby_by_propensity_score(estimands) + groups = Dict() + for Ψ in estimands + propensity_score_key_ = propensity_score_key(Ψ) + outcome_mean_key_ = outcome_mean_key(Ψ) + if haskey(groups, propensity_score_key_) + propensity_score_group = groups[propensity_score_key_] + if haskey(propensity_score_group, outcome_mean_key_) + push!(propensity_score_group[outcome_mean_key_], Ψ) + else + propensity_score_group[outcome_mean_key_] = Any[Ψ] + end + else + groups[propensity_score_key_] = Dict() + groups[propensity_score_key_][outcome_mean_key_] = Any[Ψ] + end + + end + return Dict(key => vcat(values(groups[key])...) for key in keys(groups)) +end + """ groups_ordering(estimands) @@ -142,26 +150,18 @@ work reasonably well in practice. It could be optimized further by: - Brute forcing the ordering of these groups to find an optimal one. """ function groups_ordering(estimands; brute_force=false, do_shuffle=true, rng=Random.default_rng(), verbosity=0) - # Sort estimands based on propensity_score first and outcome_mean second - estimands_and_nuisances = [] - for Ψ in estimands - η = TMLE.get_relevant_factors(Ψ) - push!(estimands_and_nuisances, (Ψ, η.propensity_score, η.outcome_mean)) - end - sort!(estimands_and_nuisances, by = x -> (Tuple(TMLE.variables(ps) for ps in x[2]), TMLE.variables(x[3]))) - - # Sorted estimands only - estimands = [x[1] for x in estimands_and_nuisances] + # Group estimands based on propensity_score first and outcome_mean second + groups = groupby_by_propensity_score(estimands) # Brute force on the propensity score groups if brute_force return brute_force_ordering(estimands; - permutation_generator = propensity_score_group_based_permutation_generator(estimands, estimands_and_nuisances), + permutation_generator = propensity_score_group_based_permutation_generator(groups), do_shuffle=do_shuffle, rng=rng, verbosity=verbosity ) else - return estimands + return vcat(values(groups)...) end end \ No newline at end of file diff --git a/src/estimands.jl b/src/estimands.jl index f2cb0439..2c442811 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -121,5 +121,13 @@ to_dict(Ψ::ComposedEstimand) = Dict( :args => [to_dict(x) for x in Ψ.args] ) +propensity_score_key(Ψ::ComposedEstimand) = Tuple(unique(Iterators.flatten(propensity_score_key(arg) for arg in Ψ.args))) +outcome_mean_key(Ψ::ComposedEstimand) = Tuple(unique(outcome_mean_key(arg) for arg in Ψ.args)) + +n_uniques_nuisance_functions(Ψ::ComposedEstimand) = length(propensity_score_key(Ψ)) + length(outcome_mean_key(Ψ)) + nuisance_functions_iterator(Ψ::ComposedEstimand) = Iterators.flatten(nuisance_functions_iterator(arg) for arg in Ψ.args) + +identify(method::AdjustmentMethod, Ψ::ComposedEstimand, scm::SCM) = + ComposedEstimand(Ψ.f, Tuple(identify(method, arg, scm) for arg ∈ Ψ.args)) \ No newline at end of file diff --git a/test/counterfactual_mean_based/adjustment.jl b/test/adjustment.jl similarity index 100% rename from test/counterfactual_mean_based/adjustment.jl rename to test/adjustment.jl diff --git a/test/estimand_ordering.jl b/test/estimand_ordering.jl index e18ff7e3..b6fbbe22 100644 --- a/test/estimand_ordering.jl +++ b/test/estimand_ordering.jl @@ -47,6 +47,24 @@ causal_estimands = [ ] statistical_estimands = [identify(x, scm) for x in causal_estimands] +@testset "Test misc" begin + # Test groupby_by_propensity_score + groups = TMLE.groupby_by_propensity_score(statistical_estimands) + @test groups[((:T₂, :W₂),)] == [statistical_estimands[4]] + @test groups[((:T₃, :W₂, :W₃),)] == [statistical_estimands[7]] + @test groups[((:T₁, :W₁, :W₂), (:T₃, :W₂, :W₃))] == [statistical_estimands[end]] + @test Set(groups[((:T₁, :W₁, :W₂),)]) == Set([statistical_estimands[1], statistical_estimands[2], statistical_estimands[6]]) + @test Set(groups[((:T₁, :W₁, :W₂), (:T₂, :W₂))]) == Set([statistical_estimands[3], statistical_estimands[5]]) + @test size(vcat(values(groups)...)) == size(statistical_estimands) + + # Test PS groups permutations + factorial(5) + permutations = TMLE.propensity_score_group_based_permutation_generator(groups) + @test length(permutations) == factorial(length(groups)) + for permutation in permutations + @test Set(permutation) == Set(statistical_estimands) + end +end @testset "Test ordering strategies" begin # Estimand ID || Required models # 1 || (T₁, Y₁|T₁) @@ -72,6 +90,7 @@ statistical_estimands = [identify(x, scm) for x in causal_estimands] ) @test TMLE.evaluate_proxy_costs(statistical_estimands, η_counts) == (4, 9) @test TMLE.get_min_maxmem_lowerbound(statistical_estimands) == 3 + # The brute force solution returns the optimal solution optimal_ordering = @test_logs (:info, "Lower bound reached, stopping.") brute_force_ordering(statistical_estimands, verbosity=1, rng=StableRNG(123)) @test TMLE.evaluate_proxy_costs(optimal_ordering, η_counts) == (3, 9) @@ -79,15 +98,40 @@ statistical_estimands = [identify(x, scm) for x in causal_estimands] bad_ordering = statistical_estimands[[1, 7, 3, 6, 2, 5, 8, 4]] @test TMLE.evaluate_proxy_costs(bad_ordering, η_counts) == (6, 9) # Without the brute force on groups, the solution is not necessarily optimal - # but still widely improved + # but still improved ordering_from_groups = groups_ordering(bad_ordering) @test TMLE.evaluate_proxy_costs(ordering_from_groups, η_counts) == (4, 9) # Adding a layer of brute forcing results in an optimal ordering - ordering_from_groups_with_brute_force = groups_ordering(bad_ordering, brute_force=true) @test TMLE.evaluate_proxy_costs(ordering_from_groups_with_brute_force, η_counts) == (3, 9) end +@testset "Test ordering strategies with Composed Estimands" begin + ATE₁ = ATE( + outcome=:Y₁, + treatment_values=(T₁=(case=1, control=0),) + ) + ATE₂ = ATE( + outcome=:Y₁, + treatment_values=(T₁=(case=2, control=1),) + ) + diff = ComposedEstimand(-, (ATE₁, ATE₂)) + ATE₃ = ATE( + outcome=:Y₁, + treatment_values=(T₂=(case=1, control=0),) + ) + estimands = [identify(x, scm) for x in [ATE₁, ATE₃, diff, ATE₂]] + η_counts = TMLE.nuisance_function_counts(estimands) + @test TMLE.get_min_maxmem_lowerbound(estimands) == 2 + @test TMLE.evaluate_proxy_costs(estimands, η_counts) == (4, 4) + # Brute Force + optimal_ordering = brute_force_ordering(estimands) + @test TMLE.evaluate_proxy_costs(optimal_ordering, η_counts) == (2, 4) + # PS Group + grouped_ordering = groups_ordering(estimands, brute_force=true) + @test TMLE.evaluate_proxy_costs(grouped_ordering, η_counts) == (2, 4) + +end end diff --git a/test/estimands.jl b/test/estimands.jl index f621a1c1..cf91c52e 100644 --- a/test/estimands.jl +++ b/test/estimands.jl @@ -27,6 +27,24 @@ end @test TMLE.variables(η) == (:Y, :T, :W, :T₁, :W₁, :T₂, :W₂₁, :W₂₂) end +@testset "Test ComposedEstimand" begin + ATE₁ = ATE( + outcome=:Y, + treatment_values = (T₁=(case=1, control=0), T₂=(case=1, control=0)), + treatment_confounders = (T₁=[:W], T₂=[:W]) + ) + ATE₂ = ATE( + outcome=:Y, + treatment_values = (T₁=(case=2, control=1), T₂=(case=2, control=1)), + treatment_confounders = (T₁=[:W], T₂=[:W]) + ) + diff = ComposedEstimand(-, (ATE₁, ATE₂)) + + @test TMLE.propensity_score_key(diff) == ((:T₁, :W), (:T₂, :W)) + @test TMLE.outcome_mean_key(diff) == ((:Y, :T₁, :T₂, :W),) +end + + end true \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 9049d620..06c5b165 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,12 +5,13 @@ TEST_DIR = joinpath(pkgdir(TMLE), "test") @time begin @test include(joinpath(TEST_DIR, "utils.jl")) + @test include(joinpath(TEST_DIR, "scm.jl")) + @test include(joinpath(TEST_DIR, "adjustment.jl")) @test include(joinpath(TEST_DIR, "estimands.jl")) @test include(joinpath(TEST_DIR, "estimators_and_estimates.jl")) @test include(joinpath(TEST_DIR, "missing_management.jl")) @test include(joinpath(TEST_DIR, "composition.jl")) @test include(joinpath(TEST_DIR, "treatment_transformer.jl")) - @test include(joinpath(TEST_DIR, "scm.jl")) # Requires Extensions if VERSION >= v"1.9" @test include(joinpath(TEST_DIR, "configuration.jl")) @@ -26,5 +27,4 @@ TEST_DIR = joinpath(pkgdir(TMLE), "test") @test include(joinpath(TEST_DIR, "counterfactual_mean_based/double_robustness_ate.jl")) @test include(joinpath(TEST_DIR, "counterfactual_mean_based/double_robustness_iate.jl")) @test include(joinpath(TEST_DIR, "counterfactual_mean_based/3points_interactions.jl")) - @test include(joinpath(TEST_DIR, "counterfactual_mean_based/adjustment.jl")) end \ No newline at end of file From 3bc0c67acc189202a7e6627c2d88a5e3da852a5e Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 27 Nov 2023 10:52:52 +0000 Subject: [PATCH 136/151] test emptyIC --- src/counterfactual_mean_based/estimates.jl | 12 +++++++++++- .../double_robustness_ate.jl | 7 ++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/counterfactual_mean_based/estimates.jl b/src/counterfactual_mean_based/estimates.jl index 8acb26ed..5936253f 100644 --- a/src/counterfactual_mean_based/estimates.jl +++ b/src/counterfactual_mean_based/estimates.jl @@ -55,9 +55,19 @@ function to_dict(estimate::T) where T <: EICEstimate ) end -emptyIC(estimate::T) where T <: EICEstimate = +emptyIC(estimate::T, pval_threshold::Nothing) where T <: EICEstimate = T(estimate.estimand, estimate.estimate, estimate.std, estimate.n, []) +function emptyIC(estimate::T, pval_threshold::Float64) where T <: EICEstimate + pval = pvalue(OneSampleZTest(estimate)) + return pval < pval_threshold ? estimate : emptyIC(estimate, nothing) +end + +emptyIC(estimate::T; pval_threshold=nothing) where T <: EICEstimate = emptyIC(estimate, pval_threshold) + + + + function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) testresult = OneSampleTTest(est) data = [estimate(est) confint(testresult) pvalue(testresult);] diff --git a/test/counterfactual_mean_based/double_robustness_ate.jl b/test/counterfactual_mean_based/double_robustness_ate.jl index 236da304..f709581a 100644 --- a/test/counterfactual_mean_based/double_robustness_ate.jl +++ b/test/counterfactual_mean_based/double_robustness_ate.jl @@ -11,7 +11,7 @@ using StableRNGs using StatsBase using LogExpFunctions -include(joinpath(dirname(@__DIR__), "helper_fns.jl")) +include(joinpath(pkgdir(TMLE), "test", "helper_fns.jl")) """ Q and G are two logistic models @@ -120,6 +120,11 @@ end results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + # Test emptyIC function + @test emptyIC(results.tmle).IC == [] + pval = pvalue(OneSampleZTest(results.tmle)) + @test emptyIC(results.tmle, pval_threshold=0.9pval).IC == [] + @test emptyIC(results.tmle, pval_threshold=1.1pval) === results.tmle # The initial estimate is far away naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) From 018bc47b88a159fc897ced3ff534ddccce37774d Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 27 Nov 2023 13:57:49 +0000 Subject: [PATCH 137/151] remove bias from IC in OSE --- src/counterfactual_mean_based/estimates.jl | 2 -- src/counterfactual_mean_based/estimators.jl | 4 +++- src/estimators.jl | 2 +- .../double_robustness_ate.jl | 10 +++++----- .../double_robustness_iate.jl | 16 ++++++++++------ test/helper_fns.jl | 8 -------- 6 files changed, 19 insertions(+), 23 deletions(-) diff --git a/src/counterfactual_mean_based/estimates.jl b/src/counterfactual_mean_based/estimates.jl index 5936253f..66e7aac6 100644 --- a/src/counterfactual_mean_based/estimates.jl +++ b/src/counterfactual_mean_based/estimates.jl @@ -66,8 +66,6 @@ end emptyIC(estimate::T; pval_threshold=nothing) where T <: EICEstimate = emptyIC(estimate, pval_threshold) - - function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) testresult = OneSampleTTest(est) data = [estimate(est) confint(testresult) pvalue(testresult);] diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 71c814df..7bb17b89 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -272,10 +272,12 @@ function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dic # Gradient and estimate IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) + IC_mean = mean(IC) + IC .-= IC_mean σ̂ = std(IC) n = size(IC, 1) verbosity >= 1 && @info "Done." - return OSEstimate(Ψ, Ψ̂ + mean(IC), σ̂, n, IC), cache + return OSEstimate(Ψ, Ψ̂ + IC_mean, σ̂, n, IC), cache end ##################################################################### diff --git a/src/estimators.jl b/src/estimators.jl index da8ddc31..4dcf382c 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -166,7 +166,7 @@ function _compose(f, estimates...; backend=AD.ZygoteBackend()) return collect(f₀), σ₀, n end -function Statistics.cov(estimates...) +function Statistics.cov(estimates::Vararg{EICEstimate}) X = hcat([r.IC for r in estimates]...) return Statistics.cov(X, dims=1, corrected=true) end \ No newline at end of file diff --git a/test/counterfactual_mean_based/double_robustness_ate.jl b/test/counterfactual_mean_based/double_robustness_ate.jl index f709581a..65626aa0 100644 --- a/test/counterfactual_mean_based/double_robustness_ate.jl +++ b/test/counterfactual_mean_based/double_robustness_ate.jl @@ -119,7 +119,7 @@ end dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.ose, ; atol=1e-10) # Test emptyIC function @test emptyIC(results.tmle).IC == [] pval = pvalue(OneSampleZTest(results.tmle)) @@ -155,7 +155,7 @@ end dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-6) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-6) # The initial estimate is far away naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @@ -187,7 +187,7 @@ end results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-10) # The initial estimate is far away naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @@ -226,7 +226,7 @@ end dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₁, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-10) # When Q is well specified but G is misspecified models = ( Y = with_encoder(LinearRegressor()), @@ -252,7 +252,7 @@ end dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, ATE₁₁₋₀₀, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-10) # When Q is well specified but G is misspecified models = ( diff --git a/test/counterfactual_mean_based/double_robustness_iate.jl b/test/counterfactual_mean_based/double_robustness_iate.jl index 6456a148..540278f0 100644 --- a/test/counterfactual_mean_based/double_robustness_iate.jl +++ b/test/counterfactual_mean_based/double_robustness_iate.jl @@ -12,7 +12,7 @@ using MLJModels using MLJLinearModels using LogExpFunctions -include(joinpath(dirname(@__DIR__), "helper_fns.jl")) +include(joinpath(pkgdir(TMLE), "test", "helper_fns.jl")) cont_interacter = InteractionTransformer(order=2) |> LinearRegressor cat_interacter = InteractionTransformer(order=2) |> LogisticClassifier(lambda=1.) @@ -187,7 +187,7 @@ end dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-9) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-9) # The initial estimate is far away naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @@ -202,7 +202,7 @@ end dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-9) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-9) # The initial estimate is far away naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @@ -232,7 +232,7 @@ end dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-10) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-10) # The initial estimate is far away naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @@ -271,7 +271,9 @@ end ) dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-5) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-10) + # The initial estimate is far away naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) @@ -285,7 +287,9 @@ end ) dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) - test_fluct_mean_inf_curve_lower_than_initial(results.tmle, results.ose) + test_mean_inf_curve_almost_zero(results.tmle; atol=1e-5) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-15) + # The initial estimate is far away naive = NAIVE(models.Y) naive_result, cache = naive(Ψ, dataset; cache=cache, verbosity=0) diff --git a/test/helper_fns.jl b/test/helper_fns.jl index f3f11fe3..ffbfb223 100644 --- a/test/helper_fns.jl +++ b/test/helper_fns.jl @@ -53,14 +53,6 @@ The TMLE is supposed to solve the EIC score equation. """ test_mean_inf_curve_almost_zero(tmle_result::TMLE.EICEstimate; atol=1e-10) = @test mean(tmle_result.IC) ≈ 0.0 atol=atol -""" - test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.TMLEResult) - -This cqnnot be guaranteed in general since a well specified maximum -likelihood estimator also solves the score equation. -""" -test_fluct_mean_inf_curve_lower_than_initial(tmle_result::TMLE.TMLEstimate, ose_result::TMLE.OSEstimate) = @test abs(mean(tmle_result.IC)) < abs(mean(ose_result.IC)) - double_robust_estimators(models; resampling=CV(nfolds=3)) = ( tmle = TMLEE(models=models, machine_cache=true), ose = OSE(models=models, machine_cache=true), From 21349f591f1e5c8d03bd763b3e2d24c8737cca5c Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 28 Nov 2023 10:05:05 +0000 Subject: [PATCH 138/151] add hotelling test --- src/estimates.jl | 25 ++++++--- src/estimators.jl | 8 +-- test/composition.jl | 52 ++++++++++++++++++- .../double_robustness_iate.jl | 2 +- 4 files changed, 75 insertions(+), 12 deletions(-) diff --git a/src/estimates.jl b/src/estimates.jl index 7c071a1d..47c65020 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -137,9 +137,22 @@ Statistics.var(est::ComposedEstimate) = Performs a T test on the ComposedEstimate. """ -function HypothesisTests.OneSampleTTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleTTest(estimate(est), sqrt(est.std[1]), est.n, Ψ₀) +function HypothesisTests.OneSampleTTest(estimate::ComposedEstimate, Ψ₀=0) + @assert length(estimate.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleTTest(estimate.estimate[1], sqrt(estimate.std[1]), estimate.n, Ψ₀) +end + +function HypothesisTests.OneSampleHotellingT2Test(estimate::ComposedEstimate, Ψ₀=zeros(size(estimate.estimate, 1))) + x̄ = estimate.estimate + S = estimate.std + n, p = estimate.n, length(x̄) + p == length(Ψ₀) || + throw(DimensionMismatch("Number of variables does not match number of means")) + n > 0 || throw(ArgumentError("The input must be non-empty")) + + T² = n * HypothesisTests.At_Binv_A(x̄ .- Ψ₀, S) + F = (n - p) * T² / (p * (n - 1)) + return OneSampleHotellingT2Test(T², F, n, p, Ψ₀, x̄, S) end """ @@ -147,7 +160,7 @@ end Performs a T test on the ComposedEstimate. """ -function HypothesisTests.OneSampleZTest(est::ComposedEstimate, Ψ₀=0) - @assert length(est.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleZTest(estimate(est), sqrt(est.std[1]), est.n, Ψ₀) +function HypothesisTests.OneSampleZTest(estimate::ComposedEstimate, Ψ₀=0) + @assert length(estimate.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." + return OneSampleZTest(estimate.estimate[1], sqrt(estimate.std[1]), estimate.n, Ψ₀) end diff --git a/src/estimators.jl b/src/estimators.jl index 4dcf382c..d971fb4a 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -157,16 +157,16 @@ function compose(f, estimates...; backend=AD.ZygoteBackend()) end function _compose(f, estimates...; backend=AD.ZygoteBackend()) - Σ = cov(estimates...) + Σ = covariance_matrix(estimates...) point_estimates = [r.estimate for r in estimates] f₀, Js = AD.value_and_jacobian(backend, f, point_estimates...) J = hcat(Js...) n = size(first(estimates).IC, 1) - σ₀ = J*Σ*J' + σ₀ = J * Σ * J' return collect(f₀), σ₀, n end -function Statistics.cov(estimates::Vararg{EICEstimate}) +function covariance_matrix(estimates...) X = hcat([r.IC for r in estimates]...) - return Statistics.cov(X, dims=1, corrected=true) + return cov(X, dims=1, corrected=true) end \ No newline at end of file diff --git a/test/composition.jl b/test/composition.jl index a2581974..a9ad56b8 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -7,6 +7,8 @@ using Distributions using MLJLinearModels using TMLE using CategoricalArrays +using LogExpFunctions +using HypothesisTests function make_dataset(;n=100) rng = StableRNG(123) @@ -31,7 +33,7 @@ end X = rand(n, 2) ER₁ = TMLE.TMLEstimate(Ψ, 1., 1., n, X[:, 1]) ER₂ = TMLE.TMLEstimate(Ψ, 0., 1., n, X[:, 2]) - Σ = cov(ER₁, ER₂) + Σ = TMLE.covariance_matrix(ER₁, ER₂) @test size(Σ) == (2, 2) @test Σ == cov(X) end @@ -153,6 +155,54 @@ end @test size(var(CM_result_composed)) == (3, 3) end +@testset "Test Joint Interaction" begin + # Dataset + n = 100 + rng = StableRNG(123) + + W = rand(rng, n) + + θT₁ = rand(rng, Normal(), 3) + pT₁ = softmax(W*θT₁', dims=2) + T₁ = [rand(rng, Categorical(collect(p))) for p in eachrow(pT₁)] + + θT₂ = rand(rng, Normal(), 3) + pT₂ = softmax(W*θT₂', dims=2) + T₂ = [rand(rng, Categorical(collect(p))) for p in eachrow(pT₂)] + + Y = 1 .+ W .+ T₁ .- T₂ .- T₁.*T₂ .+ rand(rng, Normal()) + dataset = ( + W = W, + T₁ = categorical(T₁), + T₂ = categorical(T₂), + Y = Y + ) + IATE₁ = IATE( + outcome = :Y, + treatment_values = (T₁=(case=2, control=1), T₂=(case=2, control=1)), + treatment_confounders = (T₁ = [:W], T₂ = [:W]) + ) + IATE₂ = IATE( + outcome = :Y, + treatment_values = (T₁=(case=3, control=1), T₂=(case=3, control=1)), + treatment_confounders = (T₁ = [:W], T₂ = [:W]) + ) + IATE₃ = IATE( + outcome = :Y, + treatment_values = (T₁=(case=3, control=2), T₂=(case=3, control=2)), + treatment_confounders = (T₁ = [:W], T₂ = [:W]) + ) + + jointIATE = ComposedEstimand((x, y, z) -> [x, y, z], (IATE₁, IATE₂, IATE₃)) + ose = OSE(models=TMLE.default_models(G=LogisticClassifier(), Q_continuous=LinearRegressor())) + jointEstimate, _ = ose(jointIATE, dataset, verbosity=0) + + testres = OneSampleHotellingT2Test(jointEstimate) + @test testres.x̄ ≈ jointEstimate.estimate + @test pvalue(testres) < 1e-10 +end + + end true \ No newline at end of file diff --git a/test/counterfactual_mean_based/double_robustness_iate.jl b/test/counterfactual_mean_based/double_robustness_iate.jl index 540278f0..ec926b35 100644 --- a/test/counterfactual_mean_based/double_robustness_iate.jl +++ b/test/counterfactual_mean_based/double_robustness_iate.jl @@ -288,7 +288,7 @@ end dr_estimators = double_robust_estimators(models) results, cache = test_coverage_and_get_results(dr_estimators, Ψ, Ψ₀, dataset; verbosity=0) test_mean_inf_curve_almost_zero(results.tmle; atol=1e-5) - test_mean_inf_curve_almost_zero(results.ose; atol=1e-15) + test_mean_inf_curve_almost_zero(results.ose; atol=1e-10) # The initial estimate is far away naive = NAIVE(models.Y) From 897d8d447a734c603fa7b7b12a4f4da1bb5300ee Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 28 Nov 2023 10:26:58 +0000 Subject: [PATCH 139/151] test n_counts for composed --- test/estimand_ordering.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/estimand_ordering.jl b/test/estimand_ordering.jl index b6fbbe22..d7bbc430 100644 --- a/test/estimand_ordering.jl +++ b/test/estimand_ordering.jl @@ -122,6 +122,12 @@ end ) estimands = [identify(x, scm) for x in [ATE₁, ATE₃, diff, ATE₂]] η_counts = TMLE.nuisance_function_counts(estimands) + η_counts == Dict( + TMLE.ConditionalDistribution(:Y₁, (:T₂, :W₂)) => 1, + TMLE.ConditionalDistribution(:T₁, (:W₁, :W₂)) => 4, + TMLE.ConditionalDistribution(:Y₁, (:T₁, :W₁, :W₂)) => 4, + TMLE.ConditionalDistribution(:T₂, (:W₂,)) => 1 + ) @test TMLE.get_min_maxmem_lowerbound(estimands) == 2 @test TMLE.evaluate_proxy_costs(estimands, η_counts) == (4, 4) # Brute Force From 245864a89c7bee4f30aa8bebadbab0d1faf8a154 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 28 Nov 2023 10:48:47 +0000 Subject: [PATCH 140/151] add emptyIC for composed estimate --- src/counterfactual_mean_based/estimates.jl | 4 ++-- src/estimands.jl | 2 +- src/estimates.jl | 7 +++++++ test/composition.jl | 20 ++++++++++++++++++++ 4 files changed, 30 insertions(+), 3 deletions(-) diff --git a/src/counterfactual_mean_based/estimates.jl b/src/counterfactual_mean_based/estimates.jl index 66e7aac6..d4f13013 100644 --- a/src/counterfactual_mean_based/estimates.jl +++ b/src/counterfactual_mean_based/estimates.jl @@ -55,7 +55,7 @@ function to_dict(estimate::T) where T <: EICEstimate ) end -emptyIC(estimate::T, pval_threshold::Nothing) where T <: EICEstimate = +emptyIC(estimate::T, ::Nothing) where T <: EICEstimate = T(estimate.estimand, estimate.estimate, estimate.std, estimate.n, []) function emptyIC(estimate::T, pval_threshold::Float64) where T <: EICEstimate @@ -63,7 +63,7 @@ function emptyIC(estimate::T, pval_threshold::Float64) where T <: EICEstimate return pval < pval_threshold ? estimate : emptyIC(estimate, nothing) end -emptyIC(estimate::T; pval_threshold=nothing) where T <: EICEstimate = emptyIC(estimate, pval_threshold) +emptyIC(estimate; pval_threshold=nothing) = emptyIC(estimate, pval_threshold) function Base.show(io::IO, ::MIME"text/plain", est::EICEstimate) diff --git a/src/estimands.jl b/src/estimands.jl index 2c442811..ffad634b 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -130,4 +130,4 @@ nuisance_functions_iterator(Ψ::ComposedEstimand) = Iterators.flatten(nuisance_functions_iterator(arg) for arg in Ψ.args) identify(method::AdjustmentMethod, Ψ::ComposedEstimand, scm::SCM) = - ComposedEstimand(Ψ.f, Tuple(identify(method, arg, scm) for arg ∈ Ψ.args)) \ No newline at end of file + ComposedEstimand(Ψ.f, Tuple(identify(method, arg, scm) for arg ∈ Ψ.args)) diff --git a/src/estimates.jl b/src/estimates.jl index 47c65020..fe51bbea 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -103,6 +103,8 @@ struct ComposedEstimate{T<:AbstractFloat} <: Estimate n::Int end +ComposedEstimate(;estimand, estimates, estimate, std, n) = ComposedEstimate(estimand, estimates, estimate, std, n) + function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) if length(est.std) !== 1 println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) @@ -164,3 +166,8 @@ function HypothesisTests.OneSampleZTest(estimate::ComposedEstimate, Ψ₀=0) @assert length(estimate.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." return OneSampleZTest(estimate.estimate[1], sqrt(estimate.std[1]), estimate.n, Ψ₀) end + +function emptyIC(estimate::ComposedEstimate, pval_threshold) + emptied_estimates = Tuple(emptyIC(e, pval_threshold) for e in estimate.estimates) + ComposedEstimate(estimate.estimand, emptied_estimates, estimate.estimate, estimate.std, estimate.n) +end \ No newline at end of file diff --git a/test/composition.jl b/test/composition.jl index a9ad56b8..54effed6 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -200,6 +200,26 @@ end testres = OneSampleHotellingT2Test(jointEstimate) @test testres.x̄ ≈ jointEstimate.estimate @test pvalue(testres) < 1e-10 + + emptied_estimate = TMLE.emptyIC(jointEstimate) + for Ψ̂ₙ in emptied_estimate.estimates + @test Ψ̂ₙ.IC == [] + end + + pval_threshold = 1e-3 + maybe_emptied_estimate = TMLE.emptyIC(jointEstimate, pval_threshold=pval_threshold) + n_empty = 0 + for i in 1:3 + pval = pvalue(OneSampleTTest(jointEstimate.estimates[i])) + maybe_emptied_IC = maybe_emptied_estimate.estimates[i].IC + if pval > pval_threshold + @test maybe_emptied_IC == [] + n_empty += 1 + else + @test length(maybe_emptied_IC) == n + end + end + @test n_empty > 0 end From f301b910919137af549fce88b5dc00a0e792dc62 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 28 Nov 2023 10:51:30 +0000 Subject: [PATCH 141/151] rename field std to cov --- src/estimates.jl | 16 ++++++++-------- test/composition.jl | 2 +- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/estimates.jl b/src/estimates.jl index fe51bbea..7b9f5120 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -99,14 +99,14 @@ struct ComposedEstimate{T<:AbstractFloat} <: Estimate estimand::ComposedEstimand estimates::Tuple estimate::Array{T} - std::Matrix{T} + cov::Matrix{T} n::Int end -ComposedEstimate(;estimand, estimates, estimate, std, n) = ComposedEstimate(estimand, estimates, estimate, std, n) +ComposedEstimate(;estimand, estimates, estimate, cov, n) = ComposedEstimate(estimand, estimates, estimate, cov, n) function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) - if length(est.std) !== 1 + if length(est.cov) !== 1 println(io, string("Estimate: ", estimate(est), "\nVariance: \n", var(est))) else testresult = OneSampleTTest(est) @@ -131,7 +131,7 @@ Distributions.estimate(est::ComposedEstimate) = Computes the estimated variance associated with the estimate. """ Statistics.var(est::ComposedEstimate) = - length(est.std) == 1 ? est.std[1] / est.n : est.std ./ est.n + length(est.cov) == 1 ? est.cov[1] / est.n : est.cov ./ est.n """ @@ -141,12 +141,12 @@ Performs a T test on the ComposedEstimate. """ function HypothesisTests.OneSampleTTest(estimate::ComposedEstimate, Ψ₀=0) @assert length(estimate.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleTTest(estimate.estimate[1], sqrt(estimate.std[1]), estimate.n, Ψ₀) + return OneSampleTTest(estimate.estimate[1], sqrt(estimate.cov[1]), estimate.n, Ψ₀) end function HypothesisTests.OneSampleHotellingT2Test(estimate::ComposedEstimate, Ψ₀=zeros(size(estimate.estimate, 1))) x̄ = estimate.estimate - S = estimate.std + S = estimate.cov n, p = estimate.n, length(x̄) p == length(Ψ₀) || throw(DimensionMismatch("Number of variables does not match number of means")) @@ -164,10 +164,10 @@ Performs a T test on the ComposedEstimate. """ function HypothesisTests.OneSampleZTest(estimate::ComposedEstimate, Ψ₀=0) @assert length(estimate.estimate) == 1 "OneSampleTTest is only implemeted for real-valued statistics." - return OneSampleZTest(estimate.estimate[1], sqrt(estimate.std[1]), estimate.n, Ψ₀) + return OneSampleZTest(estimate.estimate[1], sqrt(estimate.cov[1]), estimate.n, Ψ₀) end function emptyIC(estimate::ComposedEstimate, pval_threshold) emptied_estimates = Tuple(emptyIC(e, pval_threshold) for e in estimate.estimates) - ComposedEstimate(estimate.estimand, emptied_estimates, estimate.estimate, estimate.std, estimate.n) + ComposedEstimate(estimate.estimand, emptied_estimates, estimate.estimate, estimate.cov, estimate.n) end \ No newline at end of file diff --git a/test/composition.jl b/test/composition.jl index 54effed6..66dda390 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -89,7 +89,7 @@ end composed_estimate, cache = tmle(composed_estimand, dataset; cache=cache, verbosity=0) @test composed_estimate.estimand == CM_result_composed_tmle.estimand @test CM_result_composed_tmle.estimate == composed_estimate.estimate - @test CM_result_composed_tmle.std == composed_estimate.std + @test CM_result_composed_tmle.cov == composed_estimate.cov @test CM_result_composed_tmle.n == composed_estimate.n # Via ATE ATE₁₀ = ATE( From 34c1a061edcd7e8dd83fccf1a1189381c7630faa Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 28 Nov 2023 11:13:33 +0000 Subject: [PATCH 142/151] relax some methods constraints --- src/adjustment.jl | 2 +- src/counterfactual_mean_based/estimands.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/adjustment.jl b/src/adjustment.jl index dbc8f412..b085a51c 100644 --- a/src/adjustment.jl +++ b/src/adjustment.jl @@ -18,7 +18,7 @@ function statistical_type_from_causal_type(T) return eval(Symbol(new_typestring)) end -identify(estimand, scm::SCM; method=BackdoorAdjustment()::BackdoorAdjustment) = +identify(estimand, scm; method=BackdoorAdjustment()) = identify(method, estimand, scm) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index 270f623d..e94131ef 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -202,7 +202,7 @@ function to_dict(Ψ::T) where T <: StatisticalCMCompositeEstimand ) end -identify(method::AdjustmentMethod, Ψ::StatisticalCMCompositeEstimand, scm::SCM) = Ψ +identify(method, Ψ::StatisticalCMCompositeEstimand, scm) = Ψ function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) where T<:CausalCMCompositeEstimands # Treatment confounders From 3bd7796b5f1c27b24f1e0d090102b2b6f582f24e Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 29 Nov 2023 13:37:58 +0000 Subject: [PATCH 143/151] add SJON serialization of ComposedEstimates --- src/estimands.jl | 9 +++++++-- src/estimates.jl | 20 +++++++++++++++++--- src/estimators.jl | 11 +++++++++-- test/composition.jl | 30 +++++++++++++++++++++++++++--- 4 files changed, 60 insertions(+), 10 deletions(-) diff --git a/src/estimands.jl b/src/estimands.jl index ffad634b..7d9d4c1c 100644 --- a/src/estimands.jl +++ b/src/estimands.jl @@ -115,11 +115,16 @@ ComposedEstimand(f::String, args::AbstractVector) = ComposedEstimand(eval(Meta.p ComposedEstimand(;f, args) = ComposedEstimand(f, args) -to_dict(Ψ::ComposedEstimand) = Dict( +function to_dict(Ψ::ComposedEstimand) + fname = string(nameof(Ψ.f)) + startswith(fname, "#") && + throw(ArgumentError("The function of a ComposedEstimand cannot be anonymous to be converted to a dictionary.")) + return Dict( :type => string(ComposedEstimand), - :f => string(nameof(Ψ.f)), + :f => fname, :args => [to_dict(x) for x in Ψ.args] ) +end propensity_score_key(Ψ::ComposedEstimand) = Tuple(unique(Iterators.flatten(propensity_score_key(arg) for arg in Ψ.args))) outcome_mean_key(Ψ::ComposedEstimand) = Tuple(unique(outcome_mean_key(arg) for arg in Ψ.args)) diff --git a/src/estimates.jl b/src/estimates.jl index 7b9f5120..e91baab2 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -103,7 +103,12 @@ struct ComposedEstimate{T<:AbstractFloat} <: Estimate n::Int end -ComposedEstimate(;estimand, estimates, estimate, cov, n) = ComposedEstimate(estimand, estimates, estimate, cov, n) +to_matrix(x::Matrix) = x +to_matrix(x) = reduce(hcat, x) + +ComposedEstimate(;estimand, estimates, estimate, cov, n) = + ComposedEstimate(estimand, Tuple(estimates), estimate, to_matrix(cov), n) + function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) if length(est.cov) !== 1 @@ -133,7 +138,6 @@ Computes the estimated variance associated with the estimate. Statistics.var(est::ComposedEstimate) = length(est.cov) == 1 ? est.cov[1] / est.n : est.cov ./ est.n - """ OneSampleTTest(r::ComposedEstimate, Ψ₀=0) @@ -170,4 +174,14 @@ end function emptyIC(estimate::ComposedEstimate, pval_threshold) emptied_estimates = Tuple(emptyIC(e, pval_threshold) for e in estimate.estimates) ComposedEstimate(estimate.estimand, emptied_estimates, estimate.estimate, estimate.cov, estimate.n) -end \ No newline at end of file +end + + +to_dict(estimate::ComposedEstimate) = Dict( + :type => string(ComposedEstimate), + :estimand => to_dict(estimate.estimand), + :estimates => [to_dict(e) for e in estimate.estimates], + :estimate => estimate.estimate, + :cov => estimate.cov, + :n => estimate.n +) \ No newline at end of file diff --git a/src/estimators.jl b/src/estimators.jl index d971fb4a..dfa27633 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -97,13 +97,20 @@ Estimates all components of Ψ and then Ψ itself. """ function (estimator::Estimator)(Ψ::ComposedEstimand, dataset; cache=Dict(), verbosity=1, backend=AD.ZygoteBackend()) estimates = map(Ψ.args) do estimand - estimate, cache = estimator(estimand, dataset; cache=cache, verbosity=verbosity) + estimate, _ = estimator(estimand, dataset; cache=cache, verbosity=verbosity) estimate end f₀, σ₀, n = _compose(Ψ.f, estimates...; backend=backend) return ComposedEstimate(Ψ, estimates, f₀, σ₀, n), cache end +""" + joint_estimand(args...) + +This function is temporary and only necessary because of a bug in +the AD package. Simply call `vcat` in the future. +""" +joint_estimand(args...) = vcat(args...) """ compose(f, estimation_results::Vararg{EICEstimate, N}) where N @@ -157,7 +164,7 @@ function compose(f, estimates...; backend=AD.ZygoteBackend()) end function _compose(f, estimates...; backend=AD.ZygoteBackend()) - Σ = covariance_matrix(estimates...) + Σ = TMLE.covariance_matrix(estimates...) point_estimates = [r.estimate for r in estimates] f₀, Js = AD.value_and_jacobian(backend, f, point_estimates...) J = hcat(Js...) diff --git a/test/composition.jl b/test/composition.jl index 66dda390..4a6f9703 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -38,7 +38,7 @@ end @test Σ == cov(X) end -@testset "Test to_dict and from_dict!" begin +@testset "Test to_dict and from_dict! ComposedEstimand" begin ATE₁ = ATE( outcome=:Y, treatment_values = (T=(case=1, control=0),), @@ -53,6 +53,10 @@ end d = TMLE.to_dict(diff) diff_from_dict = TMLE.from_dict!(d) @test diff_from_dict == diff + + # Anonymous function will raise + anonymousdiff = ComposedEstimand((x,y) -> x - y, (ATE₁, ATE₂)) + @test_throws ArgumentError TMLE.to_dict(anonymousdiff) end @testset "Test composition CM(1) - CM(0) = ATE(1,0)" begin dataset = make_dataset(;n=1000) @@ -192,8 +196,7 @@ end treatment_values = (T₁=(case=3, control=2), T₂=(case=3, control=2)), treatment_confounders = (T₁ = [:W], T₂ = [:W]) ) - - jointIATE = ComposedEstimand((x, y, z) -> [x, y, z], (IATE₁, IATE₂, IATE₃)) + jointIATE = ComposedEstimand(TMLE.joint_estimand, (IATE₁, IATE₂, IATE₃)) ose = OSE(models=TMLE.default_models(G=LogisticClassifier(), Q_continuous=LinearRegressor())) jointEstimate, _ = ose(jointIATE, dataset, verbosity=0) @@ -220,6 +223,27 @@ end end end @test n_empty > 0 + + d = TMLE.to_dict(jointEstimate) + jointEstimate_fromdict = TMLE.from_dict!(d) + + @test jointEstimate.estimand == jointEstimate_fromdict.estimand + @test jointEstimate.cov == jointEstimate_fromdict.cov + @test jointEstimate.estimate == jointEstimate_fromdict.estimate + @test jointEstimate.n == jointEstimate_fromdict.n + @test length(jointEstimate_fromdict.estimates) == 3 + + if VERSION >= v"1.9" + using JSON + filename, _ = mktemp() + TMLE.write_json(filename, jointEstimate) + from_json = TMLE.read_json(filename) + @test jointEstimate.estimand == from_json.estimand + @test jointEstimate.cov == from_json.cov + @test jointEstimate.estimate == from_json.estimate + @test jointEstimate.n == from_json.n + @test length(jointEstimate_fromdict.estimates) == 3 + end end From 62e16a7adcb479c753444774599b8dba1a9f9185 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 29 Nov 2023 15:40:37 +0000 Subject: [PATCH 144/151] fix 0 dim array returned sometimes and --- src/configuration.jl | 2 +- src/estimates.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/configuration.jl b/src/configuration.jl index 09b1d95b..d767f813 100644 --- a/src/configuration.jl +++ b/src/configuration.jl @@ -36,7 +36,7 @@ Converts a dictionary to a TMLE struct. function from_dict!(d::Dict{T, Any}) where T haskey(d, T(:type)) || return Dict(key => from_dict!(val) for (key, val) in d) constructor = eval(Meta.parse(pop!(d, :type))) - return constructor(;[key => from_dict!(val) for (key, val) in d]...) + return constructor(;(key => from_dict!(val) for (key, val) in d)...) end function read_yaml end diff --git a/src/estimates.jl b/src/estimates.jl index e91baab2..51fc350e 100644 --- a/src/estimates.jl +++ b/src/estimates.jl @@ -107,7 +107,7 @@ to_matrix(x::Matrix) = x to_matrix(x) = reduce(hcat, x) ComposedEstimate(;estimand, estimates, estimate, cov, n) = - ComposedEstimate(estimand, Tuple(estimates), estimate, to_matrix(cov), n) + ComposedEstimate(estimand, Tuple(estimates), collect(estimate), to_matrix(cov), n) function Base.show(io::IO, ::MIME"text/plain", est::ComposedEstimate) From a67ab4d93a2c5a7f3afd0dc65b48d47b54e65882 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Fri, 1 Dec 2023 16:02:01 +0000 Subject: [PATCH 145/151] fix non working sorting on tuples --- src/counterfactual_mean_based/estimands.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/counterfactual_mean_based/estimands.jl b/src/counterfactual_mean_based/estimands.jl index e94131ef..e93e2d42 100644 --- a/src/counterfactual_mean_based/estimands.jl +++ b/src/counterfactual_mean_based/estimands.jl @@ -158,12 +158,13 @@ treatment_values(d::AbstractDict) = (;d...) treatment_values(d) = d get_treatment_specs(treatment_specs::NamedTuple{names, }) where names = - NamedTuple{sort(names)}(treatment_specs) + NamedTuple{Tuple(sort(collect(names)))}(treatment_specs) function get_treatment_specs(treatment_specs::NamedTuple{names, <:Tuple{Vararg{NamedTuple}}}) where names case_control = ((case=v[:case], control=v[:control]) for v in values(treatment_specs)) treatment_specs = (;zip(keys(treatment_specs), case_control)...) - return NamedTuple{sort(names)}(treatment_specs) + sorted_names = Tuple(sort(collect(names))) + return NamedTuple{sorted_names}(treatment_specs) end get_treatment_specs(treatment_specs::AbstractDict) = From b4872f579bd480b6fbbc2c00af78c3fd694cd777 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Mon, 4 Dec 2023 18:40:29 +0000 Subject: [PATCH 146/151] make config more flexible --- src/configuration.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/configuration.jl b/src/configuration.jl index d767f813..9e1aaad3 100644 --- a/src/configuration.jl +++ b/src/configuration.jl @@ -1,5 +1,5 @@ struct Configuration - estimands::AbstractVector{<:Estimand} + estimands::AbstractVector scm::Union{Nothing, SCM} adjustment::Union{Nothing, <:AdjustmentMethod} end From c2f143dda45a62cb2afe89b1ac8754bd1446cec4 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Tue, 5 Dec 2023 11:22:00 +0000 Subject: [PATCH 147/151] reexport OneSampleHotellingT2Test --- src/TMLE.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index 303af44f..85d63980 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -30,7 +30,7 @@ export CM, ATE, IATE export AVAILABLE_ESTIMANDS export TMLEE, OSE, NAIVE export ComposedEstimand -export var, estimate, OneSampleTTest, OneSampleZTest, pvalue, confint, emptyIC +export var, estimate, OneSampleTTest, OneSampleZTest, OneSampleHotellingT2Test,pvalue, confint, emptyIC export compose export TreatmentTransformer, with_encoder, encoder export BackdoorAdjustment, identify From 294f9fcd5d58fa1dac4fc961029231d3184c811a Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 6 Dec 2023 16:10:53 +0000 Subject: [PATCH 148/151] add kwargs to read_json and yaml to fix windows pb --- Project.toml | 14 ++++---------- ext/JSONExt.jl | 2 +- ext/YAMLExt.jl | 2 +- test/composition.jl | 2 +- test/configuration.jl | 6 +++--- .../non_regression_test.jl | 2 +- 6 files changed, 11 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 882fb998..76f023b7 100644 --- a/Project.toml +++ b/Project.toml @@ -27,15 +27,12 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [weakdeps] -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" -YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" [extensions] -GraphMakieExt = ["GraphMakie", "CairoMakie"] -YAMLExt = "YAML" JSONExt = "JSON" +YAMLExt = "YAML" [compat] AbstractDifferentiation = "0.6.0" @@ -45,6 +42,7 @@ Distributions = "0.25" GLM = "1.8.2" Graphs = "1.8" HypothesisTests = "0.10, 0.11" +JSON = "0.21.4" LogExpFunctions = "0.3" MLJBase = "1.0.1" MLJGLMInterface = "0.3.4" @@ -56,13 +54,9 @@ PrettyTables = "2.2" TableOperations = "1.2" Tables = "1.6" YAML = "0.4.9" -JSON = "0.21.4" Zygote = "0.6" julia = "1.6, 1.7, 1" [extras] -CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" -GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2" -YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" - +YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" diff --git a/ext/JSONExt.jl b/ext/JSONExt.jl index d1f49082..c53ecb51 100644 --- a/ext/JSONExt.jl +++ b/ext/JSONExt.jl @@ -11,7 +11,7 @@ Loads a YAML configuration file containing: - An SCM (optional) - An Adjustment Method (optional) """ -TMLE.read_json(file) = TMLE.from_dict!(JSON.parsefile(file, dicttype=Dict{Symbol, Any})) +TMLE.read_json(file; dicttype=Dict{Symbol, Any}, kwargs...) = TMLE.from_dict!(JSON.parsefile(file; dicttype=dicttype, kwargs...)) """ write_json(file, config::Configuration) diff --git a/ext/YAMLExt.jl b/ext/YAMLExt.jl index 1a06db27..d566ae6b 100644 --- a/ext/YAMLExt.jl +++ b/ext/YAMLExt.jl @@ -11,7 +11,7 @@ Loads a YAML configuration file containing: - An SCM (optional) - An Adjustment Method (optional) """ -TMLE.read_yaml(file) = TMLE.from_dict!(YAML.load_file(file, dicttype=Dict{Symbol, Any})) +TMLE.read_yaml(file; dicttype=Dict{Symbol, Any}, kwargs...) = TMLE.from_dict!(YAML.load_file(file, dicttype=dicttype, kwargs...)) """ write_yaml(file, config::Configuration) diff --git a/test/composition.jl b/test/composition.jl index 4a6f9703..c1babcc1 100644 --- a/test/composition.jl +++ b/test/composition.jl @@ -237,7 +237,7 @@ end using JSON filename, _ = mktemp() TMLE.write_json(filename, jointEstimate) - from_json = TMLE.read_json(filename) + from_json = TMLE.read_json(filename, use_mmap=false) @test jointEstimate.estimand == from_json.estimand @test jointEstimate.cov == from_json.cov @test jointEstimate.estimate == from_json.estimate diff --git a/test/configuration.jl b/test/configuration.jl index 01e43ae9..fae25097 100644 --- a/test/configuration.jl +++ b/test/configuration.jl @@ -56,7 +56,7 @@ end @test estimands_from_yaml == estimands TMLE.write_json(jsonfilename, estimands) - estimands_from_json = TMLE.read_json(jsonfilename) + estimands_from_json = TMLE.read_json(jsonfilename, use_mmap=false) @test estimands_from_json == estimands configuration = Configuration( @@ -66,7 +66,7 @@ end TMLE.write_yaml(yamlfilename, configuration) TMLE.write_json(jsonfilename, configuration) config_from_yaml = TMLE.read_yaml(yamlfilename) - config_from_json = TMLE.read_json(jsonfilename) + config_from_json = TMLE.read_json(jsonfilename, use_mmap=false) for loaded_config in (config_from_yaml, config_from_json) @test loaded_config.scm === nothing @test loaded_config.adjustment === nothing @@ -113,7 +113,7 @@ end TMLE.write_json(jsonfilename, configuration) config_from_yaml = TMLE.read_yaml(yamlfilename) - config_from_json = TMLE.read_json(jsonfilename) + config_from_json = TMLE.read_json(jsonfilename, use_mmap=false) for loaded_config in (config_from_yaml, config_from_json) scm = loaded_config.scm diff --git a/test/counterfactual_mean_based/non_regression_test.jl b/test/counterfactual_mean_based/non_regression_test.jl index 646e5476..c3ef7175 100644 --- a/test/counterfactual_mean_based/non_regression_test.jl +++ b/test/counterfactual_mean_based/non_regression_test.jl @@ -53,7 +53,7 @@ end if VERSION >= v"1.9" jsonfile = mktemp()[1] TMLE.write_json(jsonfile, [tmle_result]) - results_from_json = TMLE.read_json(jsonfile) + results_from_json = TMLE.read_json(jsonfile, use_mmap=false) regression_tests(results_from_json[1]) yamlfile = mktemp()[1] From ba6f969d57df4ddf48f8657b6f98263e93275d2b Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 6 Dec 2023 16:24:46 +0000 Subject: [PATCH 149/151] remove TMLE prefix in code --- .../clever_covariate.jl | 2 +- src/counterfactual_mean_based/estimators.jl | 42 +++++++++---------- src/counterfactual_mean_based/fluctuation.jl | 2 +- src/counterfactual_mean_based/gradient.jl | 8 ++-- src/estimators.jl | 6 +-- src/utils.jl | 2 +- 6 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/counterfactual_mean_based/clever_covariate.jl b/src/counterfactual_mean_based/clever_covariate.jl index d3d2e2cc..4ac92e77 100644 --- a/src/counterfactual_mean_based/clever_covariate.jl +++ b/src/counterfactual_mean_based/clever_covariate.jl @@ -58,7 +58,7 @@ function clever_covariate_and_weights( weighted_fluctuation=false ) # Compute the indicator values - T = TMLE.selectcols(dataset, (p.estimand.outcome for p in Gs)) + T = selectcols(dataset, (p.estimand.outcome for p in Gs)) indic_vals = indicator_values(indicator_fns(Ψ), T) weights = balancing_weights(Gs, dataset; ps_lowerbound=ps_lowerbound) if weighted_fluctuation diff --git a/src/counterfactual_mean_based/estimators.jl b/src/counterfactual_mean_based/estimators.jl index 7bb17b89..fb768c00 100644 --- a/src/counterfactual_mean_based/estimators.jl +++ b/src/counterfactual_mean_based/estimators.jl @@ -29,7 +29,7 @@ key(estimator::CMRelevantFactorsEstimator) = get_train_validation_indices(resampling::Nothing, factors, dataset) = nothing function get_train_validation_indices(resampling::ResamplingStrategy, factors, dataset) - relevant_columns = collect(TMLE.variables(factors)) + relevant_columns = collect(variables(factors)) outcome_variable = factors.outcome_mean.outcome feature_variables = filter(x -> x !== outcome_variable, relevant_columns) return Tuple(MLJBase.train_test_pairs( @@ -63,7 +63,7 @@ function (estimator::CMRelevantFactorsEstimator)(estimand, dataset; cache=Dict() verbosity > 0 && @info(fit_string(estimand)) models = estimator.models # Get train validation indices - train_validation_indices = TMLE.get_train_validation_indices(estimator.resampling, estimand, dataset) + train_validation_indices = get_train_validation_indices(estimator.resampling, estimand, dataset) # Fit propensity score propensity_score_estimate = map(estimand.propensity_score) do factor try @@ -106,7 +106,7 @@ struct TargetedCMRelevantFactorsEstimator end TargetedCMRelevantFactorsEstimator(Ψ, initial_factors_estimate; tol=nothing, ps_lowerbound=1e-8, weighted=false, machine_cache=false) = - TargetedCMRelevantFactorsEstimator(TMLE.Fluctuation(Ψ, initial_factors_estimate; + TargetedCMRelevantFactorsEstimator(Fluctuation(Ψ, initial_factors_estimate; tol=tol, ps_lowerbound=ps_lowerbound, weighted=weighted, @@ -176,12 +176,12 @@ TMLEE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, weighted function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset - TMLE.check_treatment_levels(Ψ, dataset) + check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors - relevant_factors = TMLE.get_relevant_factors(Ψ) - nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) - initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, tmle.resampling) - initial_factors_estimator = TMLE.CMRelevantFactorsEstimator(tmle.resampling, tmle.models) + relevant_factors = get_relevant_factors(Ψ) + nomissing_dataset = nomissing(dataset, variables(relevant_factors)) + initial_factors_dataset = choose_initial_dataset(dataset, nomissing_dataset, tmle.resampling) + initial_factors_estimator = CMRelevantFactorsEstimator(tmle.resampling, tmle.models) initial_factors_estimate = initial_factors_estimator(relevant_factors, initial_factors_dataset; cache=cache, verbosity=verbosity, @@ -189,10 +189,10 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict() ) # Get propensity score truncation threshold n = nrows(nomissing_dataset) - ps_lowerbound = TMLE.ps_lower_bound(n, tmle.ps_lowerbound) + ps_lowerbound = ps_lower_bound(n, tmle.ps_lowerbound) # Fluctuation initial factors verbosity >= 1 && @info "Performing TMLE..." - targeted_factors_estimator = TMLE.TargetedCMRelevantFactorsEstimator( + targeted_factors_estimator = TargetedCMRelevantFactorsEstimator( Ψ, initial_factors_estimate; tol=tmle.tol, @@ -206,7 +206,7 @@ function (tmle::TMLEE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict() machine_cache=tmle.machine_cache ) # Estimation results after TMLE - IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) + IC, Ψ̂ = gradient_and_estimate(Ψ, targeted_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) σ̂ = std(IC) n = size(IC, 1) verbosity >= 1 && @info "Done." @@ -254,12 +254,12 @@ OSE(;models=default_models(), resampling=nothing, ps_lowerbound=1e-8, machine_ca function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset - TMLE.check_treatment_levels(Ψ, dataset) + check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors - initial_factors = TMLE.get_relevant_factors(Ψ) - nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(initial_factors)) - initial_factors_dataset = TMLE.choose_initial_dataset(dataset, nomissing_dataset, estimator.resampling) - initial_factors_estimator = TMLE.CMRelevantFactorsEstimator(estimator.resampling, estimator.models) + initial_factors = get_relevant_factors(Ψ) + nomissing_dataset = nomissing(dataset, variables(initial_factors)) + initial_factors_dataset = choose_initial_dataset(dataset, nomissing_dataset, estimator.resampling) + initial_factors_estimator = CMRelevantFactorsEstimator(estimator.resampling, estimator.models) initial_factors_estimate = initial_factors_estimator( initial_factors, initial_factors_dataset; @@ -268,10 +268,10 @@ function (estimator::OSE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dic ) # Get propensity score truncation threshold n = nrows(nomissing_dataset) - ps_lowerbound = TMLE.ps_lower_bound(n, estimator.ps_lowerbound) + ps_lowerbound = ps_lower_bound(n, estimator.ps_lowerbound) # Gradient and estimate - IC, Ψ̂ = TMLE.gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) + IC, Ψ̂ = gradient_and_estimate(Ψ, initial_factors_estimate, nomissing_dataset; ps_lowerbound=ps_lowerbound) IC_mean = mean(IC) IC .-= IC_mean σ̂ = std(IC) @@ -290,10 +290,10 @@ end function (estimator::NAIVE)(Ψ::StatisticalCMCompositeEstimand, dataset; cache=Dict(), verbosity=1) # Check the estimand against the dataset - TMLE.check_treatment_levels(Ψ, dataset) + check_treatment_levels(Ψ, dataset) # Initial fit of the SCM's relevant factors - relevant_factors = TMLE.get_relevant_factors(Ψ) - nomissing_dataset = TMLE.nomissing(dataset, TMLE.variables(relevant_factors)) + relevant_factors = get_relevant_factors(Ψ) + nomissing_dataset = nomissing(dataset, variables(relevant_factors)) outcome_mean_estimate = MLConditionalDistributionEstimator(estimator.model)( relevant_factors.outcome_mean, dataset; diff --git a/src/counterfactual_mean_based/fluctuation.jl b/src/counterfactual_mean_based/fluctuation.jl index b693d47c..786c60dd 100644 --- a/src/counterfactual_mean_based/fluctuation.jl +++ b/src/counterfactual_mean_based/fluctuation.jl @@ -1,6 +1,6 @@ mutable struct Fluctuation <: MLJBase.Supervised Ψ::StatisticalCMCompositeEstimand - initial_factors::TMLE.MLCMRelevantFactors + initial_factors::MLCMRelevantFactors tol::Union{Nothing, Float64} ps_lowerbound::Float64 weighted::Bool diff --git a/src/counterfactual_mean_based/gradient.jl b/src/counterfactual_mean_based/gradient.jl index ff715039..882e325e 100644 --- a/src/counterfactual_mean_based/gradient.jl +++ b/src/counterfactual_mean_based/gradient.jl @@ -16,7 +16,7 @@ function counterfactual_aggregate(Ψ::StatisticalCMCompositeEstimand, Q, dataset # Loop over Treatment settings for (vals, sign) in indicator_fns(Ψ) # Counterfactual dataset for a given treatment setting - T_ct = TMLE.counterfactualTreatment(vals, Ttemplate) + T_ct = counterfactualTreatment(vals, Ttemplate) X_ct = merge(X, T_ct) # Counterfactual mean ctf_agg .+= sign .* expected_value(Q, X_ct) @@ -56,9 +56,9 @@ end function gradient_and_estimate(Ψ::StatisticalCMCompositeEstimand, factors, dataset; ps_lowerbound=1e-8) Q = factors.outcome_mean G = factors.propensity_score - ctf_agg = TMLE.counterfactual_aggregate(Ψ, Q, dataset) - Ψ̂ = TMLE.compute_estimate(ctf_agg, TMLE.train_validation_indices_from_factors(factors)) - IC = TMLE.∇YX(Ψ, Q, G, dataset; ps_lowerbound = ps_lowerbound) .+ TMLE.∇W(ctf_agg, Ψ̂) + ctf_agg = counterfactual_aggregate(Ψ, Q, dataset) + Ψ̂ = compute_estimate(ctf_agg, train_validation_indices_from_factors(factors)) + IC = ∇YX(Ψ, Q, G, dataset; ps_lowerbound = ps_lowerbound) .+ ∇W(ctf_agg, Ψ̂) return IC, Ψ̂ end diff --git a/src/estimators.jl b/src/estimators.jl index dfa27633..a7fb1cbb 100644 --- a/src/estimators.jl +++ b/src/estimators.jl @@ -19,7 +19,7 @@ function (estimator::MLConditionalDistributionEstimator)(estimand, dataset; cach end verbosity > 0 && @info(string("Estimating: ", string_repr(estimand))) # Otherwise estimate - relevant_dataset = TMLE.nomissing(dataset, TMLE.variables(estimand)) + relevant_dataset = nomissing(dataset, variables(estimand)) # Fit Conditional DIstribution using MLJ X = selectcols(relevant_dataset, estimand.parents) y = Tables.getcolumn(relevant_dataset, estimand.outcome) @@ -57,7 +57,7 @@ function (estimator::SampleSplitMLConditionalDistributionEstimator)(estimand, da # Otherwise estimate verbosity > 0 && @info(string("Estimating: ", string_repr(estimand))) - relevant_dataset = TMLE.selectcols(dataset, TMLE.variables(estimand)) + relevant_dataset = selectcols(dataset, variables(estimand)) nfolds = size(estimator.train_validation_indices, 1) machines = Vector{Machine}(undef, nfolds) # Fit Conditional Distribution on each training split using MLJ @@ -164,7 +164,7 @@ function compose(f, estimates...; backend=AD.ZygoteBackend()) end function _compose(f, estimates...; backend=AD.ZygoteBackend()) - Σ = TMLE.covariance_matrix(estimates...) + Σ = covariance_matrix(estimates...) point_estimates = [r.estimate for r in estimates] f₀, Js = AD.value_and_jacobian(backend, f, point_estimates...) J = hcat(Js...) diff --git a/src/utils.jl b/src/utils.jl index 9afb9dca..a588cebd 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -63,7 +63,7 @@ end last_fluctuation(cache) = cache[:last_fluctuation] function last_fluctuation_epsilon(cache) - mach = TMLE.last_fluctuation(cache).outcome_mean.machine + mach = last_fluctuation(cache).outcome_mean.machine fp = fitted_params(fitted_params(mach).fitresult.one_dimensional_path) return fp.coef end From b4083fbd782f63ef9575f9fe68eeb6cecf3a5e11 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 6 Dec 2023 16:27:33 +0000 Subject: [PATCH 150/151] remove precompilation workload for now --- src/TMLE.jl | 104 ---------------------------------------------------- 1 file changed, 104 deletions(-) diff --git a/src/TMLE.jl b/src/TMLE.jl index 85d63980..3c869b3b 100644 --- a/src/TMLE.jl +++ b/src/TMLE.jl @@ -60,109 +60,5 @@ include("counterfactual_mean_based/gradient.jl") include("configuration.jl") -# ############################################################################# -# PRECOMPILATION WORKLOAD -# ############################################################################# - -function run_precompile_workload() - @setup_workload begin - # Putting some things in `@setup_workload` instead of `@compile_workload` can reduce the size of the - # precompile file and potentially make loading faster. - n = 1000 - C₁ = rand(n) - W₁ = rand(n) - W₂ = rand(n) - μT₁ = logistic.(1 .+ W₁ .- W₂) - T₁ = categorical(rand(n) .< μT₁) - μT₂ = logistic.(1 .+ W₁ .+ 2W₂) - T₂ = categorical(rand(n) .< μT₂) - μY = 1 .+ float(T₁) .+ 2W₂ .- C₁ - Y₁ = μY .+ rand(n) - Y₂ = categorical(rand(n) .< logistic.(μY)) - - dataset = (C₁=C₁, W₁=W₁, W₂=W₂, T₁=T₁, T₂=T₂, Y₁=Y₁, Y₂=Y₂) - - @compile_workload begin - # SCM constructors - ## Incremental - scm = SCM() - push!(scm, SE(:Y₁, [:T₁, :W₁, :W₂], model=LinearRegressor())) - push!(scm, SE(:T₁, [:W₁, :W₂])) - setmodel!(scm.T₁, LinearBinaryClassifier()) - ## Implicit through estimand - for estimand_type in [CM, ATE, IATE] - estimand_type( - outcome=:Y₁, - treatment=(T₁=true,), - confounders=[:W₁, :W₂], - outcome_model=LinearRegressor() - ) - end - ## Complete - # Not using the `with_encoder` for now because it crashes precompilation - # and fit can still happen somehow. - scm = SCM( - SE(:Y₁, [:T₁, :W₁, :W₂], model=LinearRegressor()), - SE(:T₁, [:W₁, :W₂],model=LinearBinaryClassifier()), - SE(:Y₂, [:T₁, :T₂, :W₁, :W₂, :C₁], model=LinearBinaryClassifier()), - SE(:T₂, [:W₁, :W₂],model=LinearBinaryClassifier()), - ) - - # Estimate some parameters - Ψ₁ = CM( - scm, - outcome =:Y₁, - treatment=(T₁=true,), - ) - tmle_result₁, fluctuation = tmle!(Ψ₁, dataset, verbosity=0) - OneSampleTTest(tmle_result₁) - OneSampleZTest(tmle_result₁) - - ose_result₁, fluctuation = ose!(Ψ₁, dataset, verbosity=0) - OneSampleTTest(ose_result₁) - OneSampleZTest(ose_result₁) - - naive_plugin_estimate(Ψ₁) - # ATE - Ψ₂ = ATE( - scm, - outcome=:Y₂, - treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)) - ) - tmle_result₂, fluctuation = tmle!(Ψ₂, dataset, verbosity=0) - OneSampleTTest(tmle_result₂) - OneSampleZTest(tmle_result₂) - - ose_result₂, fluctuation = ose!(Ψ₂, dataset, verbosity=0) - OneSampleTTest(ose_result₂) - OneSampleZTest(ose_result₂) - - naive_plugin_estimate(Ψ₂) - # IATE - Ψ₃ = IATE( - scm, - outcome=:Y₂, - treatment=(T₁=(case=true, control=false), T₂=(case=true, control=false)) - ) - tmle_result₃, fluctuation = tmle!(Ψ₃, dataset, verbosity=0) - OneSampleTTest(tmle_result₃) - OneSampleZTest(tmle_result₃) - - ose_result₃, fluctuation = ose!(Ψ₃, dataset, verbosity=0) - OneSampleTTest(ose_result₃) - OneSampleZTest(ose_result₃) - - naive_plugin_estimate(Ψ₃) - # Composition - composed_result = compose((x,y) -> x - y, tmle_result₂, tmle_result₁) - OneSampleTTest(composed_result) - OneSampleZTest(composed_result) - - end - end - -end - -# run_precompile_workload() end From 9c05e87f047120045d8ce7f9f6829a730c7995a4 Mon Sep 17 00:00:00 2001 From: Olivier Labayle Date: Wed, 6 Dec 2023 16:40:19 +0000 Subject: [PATCH 151/151] up deps and remove unfortunate R file --- Project.toml | 2 +- sandbox.R | 53 ----------------------------------------------- test/Project.toml | 2 +- 3 files changed, 2 insertions(+), 55 deletions(-) delete mode 100644 sandbox.R diff --git a/Project.toml b/Project.toml index 76f023b7..ea971f96 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ LogExpFunctions = "0.3" MLJBase = "1.0.1" MLJGLMInterface = "0.3.4" MLJModels = "0.15, 0.16" -MetaGraphsNext = "0.6" +MetaGraphsNext = "0.7" Missings = "1.0" PrecompileTools = "1.1.1" PrettyTables = "2.2" diff --git a/sandbox.R b/sandbox.R deleted file mode 100644 index 01387d93..00000000 --- a/sandbox.R +++ /dev/null @@ -1,53 +0,0 @@ -library(data.table) -library(tmle3) -library(sl3) - -data = read.csv("/Users/olivierlabayle/Dev/TARGENE/TMLE.jl/test/data/perinatal.csv") - -node_list <- list( - W = c( - "apgar1", "apgar5", "gagebrth", "mage", "meducyrs", "sexn" - ), - A = "parity01", - Y = "haz01" -) - -glm = Lrnr_glm$new() -lrn_mean = Lrnr_mean$new() -sl <- Lrnr_sl$new(learners = Stack$new(glm, lrn_mean), metalearner = Lrnr_nnls$new()) - -learner_list <- list(A = glm, Y = glm) -# learner_list = list(A=sl, Y = sl) - -ate_spec <- tmle_ATE( - treatment_level = 1, - control_level = 0 -) - -tmle_task <- ate_spec$make_tmle_task(data, node_list) -initial_likelihood <- ate_spec$make_initial_likelihood( - tmle_task, - learner_list -) - -targeted_likelihood_cv <- Targeted_Likelihood$new(initial_likelihood) - -targeted_likelihood_no_cv <- - Targeted_Likelihood$new(initial_likelihood, - updater = list(cvtmle = FALSE) - ) - -tmle_params_cv <- ate_spec$make_params(tmle_task, targeted_likelihood_cv) -tmle_params_no_cv <- ate_spec$make_params(tmle_task, targeted_likelihood_no_cv) - -tmle_no_cv <- fit_tmle3( - tmle_task, targeted_likelihood_no_cv, tmle_params_no_cv, - targeted_likelihood_no_cv$updater -) -tmle_no_cv - -tmle_cv <- fit_tmle3( - tmle_task, targeted_likelihood_cv, tmle_params_cv, - targeted_likelihood_cv$updater -) -tmle_cv \ No newline at end of file diff --git a/test/Project.toml b/test/Project.toml index e0d9f184..e175e2de 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -25,7 +25,7 @@ YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6" [compat] CSV = "0.10" DataFrames = "1.5" -MLJLinearModels = "0.9" +MLJLinearModels = "0.10" StableRNGs = "1.0" StatisticalMeasures = "0.1.3" StatsBase = "0.33"