diff --git a/src/Utils/data.jl b/src/Utils/data.jl index 9c21fd2e..15fcbdab 100644 --- a/src/Utils/data.jl +++ b/src/Utils/data.jl @@ -283,6 +283,7 @@ end "Turn binary data into floating point data close to 0 and 1." soften(data::DataFrame, softness=0.05; scale_by_marginal=true, precision=Float32) = begin + @assert !isfpdata(data) "'soften' does not support floating point data." n_col = ncol(data) data_weighted = isweighted(data) @@ -306,10 +307,10 @@ end "Compute the marginal prob of each feature in a binary dataset." marginal_prob(data; precision=Float32) = begin - @assert isbinarydata(data) "marginal_prob only support binary data." + @assert !isfpdata(data) "'marginal_prob' does not support floating point data." n_examples = num_examples(data) map(1:num_features(data)) do idx - precision(sum(feature_values(data, idx))) / n_examples + precision(sum(coalesce.(feature_values(data, idx), 0.5))) / n_examples end end diff --git a/src/queries/queries.jl b/src/queries/queries.jl index 3316c1f3..a9482268 100644 --- a/src/queries/queries.jl +++ b/src/queries/queries.jl @@ -144,21 +144,8 @@ end Is the circuit structured-decomposable? """ -function isstruct_decomposable(root::LogicCircuit, cache=nothing)::Bool - # WARNING: this function is known to have bugs; https://github.com/Juice-jl/LogicCircuits.jl/issues/82 - result::Bool = true - f_con(_) = [BitSet()] - f_lit(n) = [BitSet(variable(n))] - f_a(_, cs) = begin - result = result && isdisjoint(vcat(cs...)...) - map(c -> reduce(union!, c), cs) - end - f_o(_, cs) = begin - result = result && (length(cs) == 0 || all(==(cs[1]), cs)) - [reduce(union, vcat(cs...))] - end - foldup_aggregate(root, f_con, f_lit, f_a, f_o, Vector{BitSet}, cache) - result +function isstruct_decomposable(root::LogicCircuit)::Bool + infer_vtree(root) !== nothing end @@ -174,26 +161,41 @@ function infer_vtree(root::LogicCircuit, cache=nothing)::Union{Vtree, Nothing} throw("Circuit not smooth. Inferring vtree not supported yet!") end - if !isstruct_decomposable(root) - return nothing # or should we throw error? - end + vtr_dict = Dict{Var, PlainVtreeLeafNode}() - f_con(_) = nothing # can we have constants when there is a vtree? + f_con(_) = (nothing, true) # give constants nothing f_lit(n) = begin - PlainVtree(variable(n)) + if variable(n) ∉ keys(vtr_dict) + vtr_dict[variable(n)] = PlainVtree(variable(n)) + end + (vtr_dict[variable(n)], true) end f_a(n, call) = begin - @assert num_children(n) == 2 "And node had $num_children(n) childern. Should be 2" - left = call(children(n)[1])::Vtree - right = call(children(n)[2])::Vtree - PlainVtreeInnerNode(left, right) + if num_children(n) != 2 + return (call(children(n)[1])[1], false) + end + + (left, ld) = call(children(n)[1]) + (right, rd) = call(children(n)[2]) + if ((left === nothing) | (right === nothing)) & (left !== right) + (left === nothing ? right : left, true) + elseif !isdisjoint(variables(left), variables(right)) + (left, false) + elseif !has_parent(left) & !has_parent(right) + (PlainVtreeInnerNode(left, right), ld & rd) + elseif has_parent(left) & has_parent(right) & (parent(left) == parent(right)) + (parent(left), ld & rd) + else + (left, false) + end end f_o(n, call) = begin - # Already checked struct-decomposable so just expand on first child @assert num_children(n) > 0 "Or node has no children" - call(children(n)[1]) - end - foldup(root, f_con, f_lit, f_a, f_o, Vtree, cache) + ccalls = map(x -> call(x), children(n)) + (ccalls[1][1], all(x -> x[1] == ccalls[1][1], ccalls) & all(x -> x[2], ccalls)) + end + res = foldup(root, f_con, f_lit, f_a, f_o, Tuple{Union{Vtree, Nothing}, Bool}) + res[2] ? res[1] : nothing end diff --git a/test/Utils/data_test.jl b/test/Utils/data_test.jl index 334a9201..22f5d74d 100644 --- a/test/Utils/data_test.jl +++ b/test/Utils/data_test.jl @@ -118,5 +118,17 @@ using CUDA: CUDA @test isgpu(dfb_gpu_split1) @test isgpu(to_gpu(wdfb1_gpu)) end + + df = DataFrame(Matrix{Union{Bool,Missing}}([true missing; missing false])) + sdf = soften(df, 0.01; scale_by_marginal = false) + @test sdf[1,1] ≈ 0.99 atol = 1e-6 + @test ismissing(sdf[1,2]) + @test ismissing(sdf[2,1]) + @test sdf[2,2] ≈ 0.01 atol = 1e-6 + sdf = soften(df, 0.01; scale_by_marginal = true) + @test sdf[1,1] ≈ 0.9975 atol = 1e-6 + @test ismissing(sdf[1,2]) + @test ismissing(sdf[2,1]) + @test sdf[2,2] ≈ 0.0025 atol = 1e-6 end diff --git a/test/queries/queries_test.jl b/test/queries/queries_test.jl index cc3afd19..6f2fb111 100644 --- a/test/queries/queries_test.jl +++ b/test/queries/queries_test.jl @@ -15,7 +15,7 @@ include("../helper/plain_logic_circuits.jl") @test isdeterministic(r1) @test isdeterministic(compile(PlainLogicCircuit, Lit(1))) - @test isstruct_decomposable(r1) + @test !isstruct_decomposable(r1) @test isstruct_decomposable(compile(PlainLogicCircuit, Lit(1))) @test variables(r1) == BitSet(1:10) @@ -40,7 +40,7 @@ include("../helper/plain_logic_circuits.jl") or1 = and1 | and2 @test !isdecomposable(or1) @test !isdeterministic(or1) - @test !isstruct_decomposable(or1) + # @test !isstruct_decomposable(or1) ####################### ors = map(1:10) do v @@ -57,7 +57,7 @@ include("../helper/plain_logic_circuits.jl") and5 = or1 & or2 @test !isdecomposable(and5) @test !isdeterministic(and5) - @test !isstruct_decomposable(and5) + # @test !isstruct_decomposable(and5) ####################### leaf1 = compile(PlainLogicCircuit, Lit(1)) @@ -93,7 +93,7 @@ include("../helper/plain_logic_circuits.jl") @test isstruct_decomposable(and2) @test isdecomposable(circuit) @test !isstruct_decomposable(circuit) - @test !isstruct_decomposable(or1 & lits[3] | lits[3] & or1) + @test isstruct_decomposable(or1 & lits[3] | lits[3] & or1) @test isstruct_decomposable(or1 & lits[3] | or1 & lits[3]) end