Skip to content

Commit

Permalink
try replace namedtuples in structs by ordereddicts
Browse files Browse the repository at this point in the history
  • Loading branch information
olivierlabayle committed Aug 2, 2024
1 parent 4fbd86a commit dbb49bd
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 109 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.16.1"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -18,6 +19,7 @@ MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MetaGraphsNext = "fa8bd995-216d-47f1-8a91-f3b68fbeb377"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
Expand Down Expand Up @@ -55,6 +57,8 @@ TableOperations = "1.2"
Tables = "1.6"
YAML = "0.4.9"
Zygote = "0.6.69"
OrderedCollections = "1.6.3"
AutoHashEquals = "2.1.0"
julia = "1.6, 1.7, 1"

[extras]
Expand Down
2 changes: 2 additions & 0 deletions src/TMLE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ using Graphs
using MetaGraphsNext
using Combinatorics
using SplitApplyCombine
using OrderedCollections
using AutoHashEquals

# #############################################################################
# EXPORTS
Expand Down
4 changes: 2 additions & 2 deletions src/configuration.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ from_dict!(x) = x
from_dict!(v::AbstractVector) = [from_dict!(x) for x in v]

"""
from_dict!(d::Dict)
from_dict!(d::AbstractDict)
Converts a dictionary to a TMLE struct.
"""
function from_dict!(d::Dict{T, Any}) where T
function from_dict!(d::AbstractDict{T, Any}) where T
haskey(d, T(:type)) || return Dict(key => from_dict!(val) for (key, val) in d)
constructor = eval(Meta.parse(pop!(d, :type)))
return constructor(;(key => from_dict!(val) for (key, val) in d)...)
Expand Down
94 changes: 42 additions & 52 deletions src/counterfactual_mean_based/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS
statistical_estimand = Symbol(:Statistical, estimand)
ex = quote
# Causal Estimand
struct $(causal_estimand) <: Estimand
@auto_hash_equals struct $(causal_estimand) <: Estimand
outcome::Symbol
treatment_values::NamedTuple
treatment_values::OrderedDict

function $(causal_estimand)(outcome, treatment_values)
outcome = Symbol(outcome)
Expand All @@ -52,17 +52,16 @@ for (estimand, (formula,)) ∈ ESTIMANDS_DOCS
end
end
# Statistical Estimand
struct $(statistical_estimand) <: Estimand
@auto_hash_equals struct $(statistical_estimand) <: Estimand
outcome::Symbol
treatment_values::NamedTuple
treatment_confounders::NamedTuple
treatment_values::OrderedDict
treatment_confounders::OrderedDict
outcome_extra_covariates::Tuple{Vararg{Symbol}}

function $(statistical_estimand)(outcome, treatment_values, treatment_confounders, outcome_extra_covariates)
outcome = Symbol(outcome)
treatment_values = get_treatment_specs(treatment_values)
treatment_variables = Tuple(keys(treatment_values))
treatment_confounders = NamedTuple{treatment_variables}([confounders_values(treatment_confounders, T) for T ∈ treatment_variables])
treatment_confounders = OrderedDict(T => confounders_values(treatment_confounders, T) for T ∈ (keys(treatment_values)))
outcome_extra_covariates = unique_sorted_tuple(outcome_extra_covariates)
return new(outcome, treatment_values, treatment_confounders, outcome_extra_covariates)
end
Expand Down Expand Up @@ -96,30 +95,31 @@ StatisticalCMCompositeEstimand = Union{(eval(Symbol(:Statistical, x)) for x in k

const AVAILABLE_ESTIMANDS = [x[1] for x ∈ ESTIMANDS_DOCS]

indicator_fns(Ψ::StatisticalCM) = Dict(values(Ψ.treatment_values) => 1.)
indicator_fns(Ψ::StatisticalCM) = Dict(Tuple(values(Ψ.treatment_values)) => 1.)

function indicator_fns(Ψ::StatisticalATE)
case = []
control = []
for treatment in Ψ.treatment_values
for treatment in values(Ψ.treatment_values)
push!(case, treatment.case)
push!(control, treatment.control)
end
return Dict(Tuple(case) => 1., Tuple(control) => -1.)
end

ncases(value, Ψ::StatisticalIATE) = sum(value[i] == Ψ.treatment_values[i].case for i in eachindex(value))
ncases(counterfactual_values, treatments_cases) = sum(counterfactual_values .== treatments_cases)

function indicator_fns(Ψ::StatisticalIATE)
N = length(treatments(Ψ))
N = length(Ψ.treatment_values)
key_vals = Pair[]
for cf in Iterators.product((values(Ψ.treatment_values[T]) for T in treatments(Ψ))...)
push!(key_vals, cf => float((-1)^(N - ncases(cf, Ψ))))
treatments_cases = Tuple(case_control.case for case_control ∈ values(Ψ.treatment_values))
for cf_values in Iterators.product((values(case_control_nt) for case_control_nt in values(Ψ.treatment_values))...)
push!(key_vals, cf_values => float((-1)^(N - ncases(cf_values, treatments_cases))))
end
return Dict(key_vals...)
end

outcome_mean(Ψ::StatisticalCMCompositeEstimand) = ExpectedValue(Ψ.outcome, Tuple(union(Ψ.outcome_extra_covariates, keys(Ψ.treatment_confounders), (Ψ.treatment_confounders)...)))
outcome_mean(Ψ::StatisticalCMCompositeEstimand) = ExpectedValue(Ψ.outcome, Tuple(union(Ψ.outcome_extra_covariates, keys(Ψ.treatment_confounders), values(Ψ.treatment_confounders)...)))

outcome_mean_key(Ψ::StatisticalCMCompositeEstimand) = variables(outcome_mean(Ψ))

Expand Down Expand Up @@ -148,36 +148,27 @@ function Base.show(io::IO, ::MIME"text/plain", Ψ::T) where T <: StatisticalCMCo
println(io, param_string)
end

function treatment_specs_to_dict(treatment_values::NamedTuple{T, <:Tuple{Vararg{NamedTuple}}}) where T
Dict(key => Dict(pairs(vals)) for (key, vals) in pairs(treatment_values))
end
case_control_dict(case_control_nt::NamedTuple) = OrderedDict(pairs(case_control_nt))
case_control_dict(value) = value

treatment_specs_to_dict(treatment_values::NamedTuple) = Dict(pairs(treatment_values))
treatment_specs_to_dict(treatment_values) = OrderedDict(key => case_control_dict(case_control_nt) for (key, case_control_nt) in treatment_values)

treatment_values(d::AbstractDict) = (;d...)
treatment_values(d) = d

confounders_values(key_value_iterable::Union{NamedTuple, Dict}, T) = unique_sorted_tuple(key_value_iterable[T])
confounders_values(key_value_iterable::Union{NamedTuple, AbstractDict}, key) = unique_sorted_tuple(key_value_iterable[key])

confounders_values(iterable, T) = unique_sorted_tuple(iterable)
confounders_values(iterable, key) = unique_sorted_tuple(iterable)

get_treatment_specs(treatment_specs::NamedTuple{names, }) where names =
NamedTuple{Tuple(sort(collect(names)))}(treatment_specs)
confounders_to_dict(treatment_confounders) = Dict(key => collect(values) for (key, values) in treatment_confounders)

function get_treatment_specs(treatment_specs::NamedTuple{names, <:Tuple{Vararg{NamedTuple}}}) where names
case_control = ((case=v[:case], control=v[:control]) for v in values(treatment_specs))
treatment_specs = (;zip(keys(treatment_specs), case_control)...)
sorted_names = Tuple(sort(collect(names)))
return NamedTuple{sorted_names}(treatment_specs)
end

get_treatment_specs(treatment_specs::AbstractDict) =
get_treatment_specs((;(key => treatment_values(val) for (key, val) in treatment_specs)...))
case_control_to_nt(scalar) = scalar

constructorname(T; prefix="TMLE.Causal") = replace(string(T), prefix => "")
case_control_to_nt(case_control_iter::Union{NamedTuple, AbstractDict}) = (control=case_control_iter[:control], case=case_control_iter[:case])

treatment_confounders_to_dict(treatment_confounders::NamedTuple) =
Dict(key => collect(vals) for (key, vals) in pairs(treatment_confounders))
get_treatment_specs(key_value_iterable) = sort(OrderedDict(Symbol(key) => case_control_to_nt(case_control_iter) for (key, case_control_iter) ∈ pairs(key_value_iterable)))

constructorname(T; prefix="TMLE.Causal") = replace(string(T), prefix => "")

"""
to_dict(Ψ::T) where T <: CausalCMCompositeEstimands
Expand All @@ -202,7 +193,7 @@ function to_dict(Ψ::T) where T <: StatisticalCMCompositeEstimand
:type => constructorname(T; prefix="TMLE.Statistical"),
:outcome => Ψ.outcome,
:treatment_values => treatment_specs_to_dict(Ψ.treatment_values),
:treatment_confounders => treatment_confounders_to_dict(Ψ.treatment_confounders),
:treatment_confounders => confounders_to_dict(Ψ.treatment_confounders),
:outcome_extra_covariates => collect(Ψ.outcome_extra_covariates)
)
end
Expand All @@ -211,13 +202,10 @@ identify(method, Ψ::StatisticalCMCompositeEstimand, scm) = Ψ

function identify(method::BackdoorAdjustment, causal_estimand::T, scm::SCM) where T<:CausalCMCompositeEstimands
# Treatment confounders
treatment_names = keys(causal_estimand.treatment_values)
treatment_names = collect(keys(causal_estimand.treatment_values))
treatment_codes = [code_for(scm.graph, treatment) for treatment ∈ treatment_names]
confounders_codes = scm.graph.graph.badjlist[treatment_codes]
treatment_confounders = NamedTuple{treatment_names}(
[[scm.graph.vertex_labels[w] for w in confounders_codes[i]]
for i in eachindex(confounders_codes)]
)
treatment_confounders = Dict(treatment_names[i] => [scm.graph.vertex_labels[w] for w in confounders_codes[i]] for i in eachindex(confounders_codes))

return statistical_type_from_causal_type(T)(;
outcome=causal_estimand.outcome,
Expand All @@ -240,15 +228,15 @@ We ensure that the values are sorted by frequency to maximize
the number of estimands passing the positivity constraint.
"""
unique_treatment_values(dataset, colnames) =
(;(colname => get_treatment_values(dataset, colname) for colname in colnames)...)
sort(OrderedDict(colname => get_treatment_values(dataset, colname) for colname in colnames))

"""
Generated from transitive treatment switches to create independent estimands.
"""
get_treatment_settings(::Union{typeof(ATE), typeof(IATE)}, treatments_unique_values::NamedTuple{names}) where names =
NamedTuple{names}([collect(zip(vals[1:end-1], vals[2:end])) for vals in values(treatments_unique_values)])
get_treatment_settings(::Union{typeof(ATE), typeof(IATE)}, treatments_unique_values)=
sort(OrderedDict(key => collect(zip(uniquevaluess[1:end-1], uniquevaluess[2:end])) for (key, uniquevaluess) in pairs(treatments_unique_values)))

get_treatment_settings(::typeof(CM), treatments_unique_values) = treatments_unique_values
get_treatment_settings(::typeof(CM), treatments_unique_values) = sort(OrderedDict(pairs(treatments_unique_values)))

get_treatment_setting(combo::Tuple{Vararg{Tuple}}) = [NamedTuple{(:control, :case)}(treatment_control_case) for treatment_control_case ∈ combo]

Expand All @@ -257,7 +245,7 @@ get_treatment_setting(combo) = collect(combo)
"""
If there is no dataset and the treatments_levels are a NamedTuple, then they are assumed correct.
"""
make_or_check_treatment_levels(treatments_levels::NamedTuple, dataset::Nothing) = treatments_levels
make_or_check_treatment_levels(treatments_levels::Union{AbstractDict, NamedTuple}, dataset::Nothing) = treatments_levels

"""
If no dataset is provided, then a NamedTuple specifying the treatment levels is expected
Expand All @@ -273,7 +261,7 @@ make_or_check_treatment_levels(treatments, dataset) = unique_treatment_values(da
"""
If a NamedTuple of treatments_levels is provided as well as a dataset then the treatment_levels are checked from the dataset.
"""
function make_or_check_treatment_levels(treatments_levels::NamedTuple, dataset)
function make_or_check_treatment_levels(treatments_levels::Union{AbstractDict, NamedTuple}, dataset)
for (treatment, treatment_levels) in zip(keys(treatments_levels), values(treatments_levels))
dataset_treatment_levels = Set(skipmissing(Tables.getcolumn(dataset, treatment)))
missing_levels = setdiff(treatment_levels, dataset_treatment_levels)
Expand All @@ -285,18 +273,20 @@ end

function _factorialEstimand(
constructor,
treatments_settings::NamedTuple{names}, outcome;
treatments_settings,
outcome;
confounders=nothing,
outcome_extra_covariates=nothing,
freq_table=nothing,
positivity_constraint=nothing,
verbosity=1
) where names
)
names = keys(treatments_settings)
components = []
for combo ∈ Iterators.product(values(treatments_settings)...)
Ψ = constructor(
outcome=outcome,
treatment_values=NamedTuple{names}(get_treatment_setting(combo)),
treatment_values=OrderedDict(zip(names, get_treatment_setting(combo))),
treatment_confounders = confounders,
outcome_extra_covariates=outcome_extra_covariates
)
Expand Down Expand Up @@ -343,7 +333,7 @@ A `JointEstimand` with causal or statistical components.
# Args
- `constructor`: CM, ATE or IATE.
- `treatments`: A NamedTuple of treatment levels (e.g. `(T=(0, 1, 2),)`) or a treatment iterator, then a dataset must be provided to infer the levels from it.
- `treatments`: An `AbstractDict` or `NamedTuple` of treatment levels (e.g. `(T=(0, 1, 2),)`), or an iterator of treatment names, in which case a dataset must be provided to infer the levels from it.
- `outcome`: The outcome variable.
- `confounders=nothing`: The generated components will inherit these confounding variables. If `nothing`, causal estimands are generated.
- `outcome_extra_covariates=()`: The generated components will inherit these `outcome_extra_covariates`.
Expand Down Expand Up @@ -450,6 +440,6 @@ end
joint_levels(Ψ::StatisticalIATE) = Iterators.product(values(Ψ.treatment_values)...)

joint_levels(Ψ::StatisticalATE) =
(Tuple(Ψ.treatment_values[T][c] for T ∈ keys(Ψ.treatment_values)) for c in (:case, :control))
(Tuple(Ψ.treatment_values[T][c] for T ∈ keys(Ψ.treatment_values)) for c in (:control, :case))

joint_levels(Ψ::StatisticalCM) = (values(Ψ.treatment_values),)
joint_levels(Ψ::StatisticalCM) = (Tuple(values(Ψ.treatment_values)),)
12 changes: 6 additions & 6 deletions src/estimands.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ end
Makes sure the defined treatment levels are present in the dataset.
"""
function check_treatment_levels(Ψ::Estimand, dataset)
for treatment_name in treatments(Ψ)
treatment_levels = levels(Tables.getcolumn(dataset, treatment_name))
treatment_settings = getproperty(Ψ.treatment_values, treatment_name)
check_treatment_settings(treatment_settings, treatment_levels, treatment_name)
for T in treatments(Ψ)
treatment_levels = levels(Tables.getcolumn(dataset, T))
treatment_settings = Ψ.treatment_values[T]
check_treatment_settings(treatment_settings, treatment_levels, T)
end
end

Expand Down Expand Up @@ -106,7 +106,7 @@ const ExpectedValue = ConditionalDistribution
### JointEstimand ###
#####################################################################

struct JointEstimand <: Estimand
@auto_hash_equals struct JointEstimand <: Estimand
args::Tuple
JointEstimand(args...) = new(Tuple(args))
end
Expand Down Expand Up @@ -144,7 +144,7 @@ end
### Composed Estimand ###
#####################################################################

struct ComposedEstimand <: Estimand
@auto_hash_equals struct ComposedEstimand <: Estimand
f::Function
estimand::JointEstimand
end
Expand Down
2 changes: 1 addition & 1 deletion src/scm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Base.show(io::IO, ::MIME"text/plain", scm::SCM) =
println(io, string_repr(scm))

split_outcome_parent_pair(outcome_parents_pair::Pair) = outcome_parents_pair
split_outcome_parent_pair(outcome_parents_pair::Dict{T, Any}) where T = outcome_parents_pair[T(:outcome)], outcome_parents_pair[T(:parents)]
split_outcome_parent_pair(outcome_parents_pair::AbstractDict{T, Any}) where T = outcome_parents_pair[T(:outcome)], outcome_parents_pair[T(:parents)]

function add_equations!(scm::SCM, equations...)
for outcome_parents_pair in equations
Expand Down
8 changes: 4 additions & 4 deletions test/adjustment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ using TMLE
@test statistical_estimand.treatment_values == causal_estimand.treatment_values
@test statistical_estimand.outcome_extra_covariates == (:C,)
end
@test statistical_estimands[1].treatment_confounders == (T₁=(:W₁, :W₂),)
@test statistical_estimands[2].treatment_confounders == (T₁=(:W₁, :W₂),)
@test statistical_estimands[3].treatment_confounders == (T₁=(:W₁, :W₂), T₂=(:W₁, :W₂))
@test statistical_estimands[4].treatment_confounders == (T₁=(:W₁, :W₂), T₂=(:W₁, :W₂))
@test statistical_estimands[1].treatment_confounders == Dict(:T₁ => (:W₁, :W₂),)
@test statistical_estimands[2].treatment_confounders == Dict(:T₁ => (:W₁, :W₂),)
@test statistical_estimands[3].treatment_confounders == Dict(:T₁ => (:W₁, :W₂), :T₂ => (:W₁, :W₂))
@test statistical_estimands[4].treatment_confounders == Dict(:T₁ => (:W₁, :W₂), :T₂ => (:W₁, :W₂))
end

@testset "Test TMLE.to_dict" begin
Expand Down
Loading

0 comments on commit dbb49bd

Please sign in to comment.