Skip to content

Commit

Permalink
Merge pull request #1770 from cmu-phil/vbc-05-16
Browse files Browse the repository at this point in the history
Include Non Gaussian cases for Local Precision and Recall tests for DAG and CPDAG respectively
  • Loading branch information
jdramsey authored May 16, 2024
2 parents 8477914 + 12cf18e commit ea4507d
Showing 1 changed file with 188 additions and 4 deletions.
192 changes: 188 additions & 4 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.Params;
import org.junit.Test;

import java.util.ArrayList;
Expand Down Expand Up @@ -111,12 +112,13 @@ public void test2() {
}

@Test
public void testDAGPrecisionRecallForLocalOnMarkovBlanket() {
public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
// Parameters without additional setting default tobe Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
Expand Down Expand Up @@ -149,14 +151,15 @@ public void testDAGPrecisionRecallForLocalOnMarkovBlanket() {
}

@Test
public void testCPDAGPrecisionRecallForLocalOnMarkovBlanket() {
public void testGaussianCPDAGPrecisionRecallForLocalOnMarkovBlanket() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
// The completed partially directed acyclic graph (CPDAG) for the given DAG.
Graph trueGraphCPDAG = GraphTransforms.dagToCpdag(trueGraph);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph CPDAG: " + trueGraphCPDAG);

SemPm pm = new SemPm(trueGraph);
// Parameters without additional setting default tobe Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
Expand Down Expand Up @@ -188,12 +191,104 @@ public void testCPDAGPrecisionRecallForLocalOnMarkovBlanket() {
}

@Test
public void testDAGPrecisionRecallForLocalOnParents() {
public void testNonGaussianDAGPrecisionRecallForLocalOnMarkovBlanket() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);

Parameters params = new Parameters();
// Manually set non-Gaussian
params.set(Params.SIMULATION_ERROR_TYPE, 3);
params.set(Params.SIMULATION_PARAM1, 1);

SemIm im = new SemIm(pm, params);
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
score.setPenaltyDiscount(2);
Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());

List<Double> acceptsPrecision = new ArrayList<>();
List<Double> acceptsRecall = new ArrayList<>();
for(Node a: accepts) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph);
System.out.println("=====================");

}
for (Node a: rejects) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph);
System.out.println("=====================");
}
}

@Test
public void testNonGaussianCPDAGPrecisionRecallForLocalOnMarkovBlanket() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
// The completed partially directed acyclic graph (CPDAG) for the given DAG.
Graph trueGraphCPDAG = GraphTransforms.dagToCpdag(trueGraph);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph CPDAG: " + trueGraphCPDAG);

SemPm pm = new SemPm(trueGraph);

Parameters params = new Parameters();
// Manually set non-Gaussian
params.set(Params.SIMULATION_ERROR_TYPE, 3);
params.set(Params.SIMULATION_PARAM1, 1);

SemIm im = new SemIm(pm, params);
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
score.setPenaltyDiscount(2);
Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());

// Compare the Est CPDAG with True graph's CPDAG.
for(Node a: accepts) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraphCPDAG);
System.out.println("=====================");

}
for (Node a: rejects) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraphCPDAG);
System.out.println("=====================");
}
}



@Test
public void testGaussianDAGPrecisionRecallForLocalOnParents() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
// Parameters without additional setting default tobe Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
Expand Down Expand Up @@ -225,14 +320,15 @@ public void testDAGPrecisionRecallForLocalOnParents() {
}

@Test
public void testCPDAGPrecisionRecallForLocalOnParents() {
public void testGaussianCPDAGPrecisionRecallForLocalOnParents() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
// The completed partially directed acyclic graph (CPDAG) for the given DAG.
Graph trueGraphCPDAG = GraphTransforms.dagToCpdag(trueGraph);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph CPDAG: " + trueGraphCPDAG);

SemPm pm = new SemPm(trueGraph);
// Parameters without additional setting default tobe Gaussian
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
Expand Down Expand Up @@ -262,4 +358,92 @@ public void testCPDAGPrecisionRecallForLocalOnParents() {
System.out.println("=====================");
}
}

@Test
public void testNonGaussianDAGPrecisionRecallForLocalOnParents() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
Parameters params = new Parameters();
// Manually set non-Gaussian
params.set(Params.SIMULATION_ERROR_TYPE, 3);
params.set(Params.SIMULATION_PARAM1, 1);

SemIm im = new SemIm(pm, params);
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
score.setPenaltyDiscount(2);
Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
// TODO VBC: confirm on the choice of ConditioningSetType.
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.LOCAL_MARKOV);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());

for(Node a: accepts) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph);
System.out.println("=====================");

}
for (Node a: rejects) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraph);
System.out.println("=====================");
}
}

@Test
public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
// The completed partially directed acyclic graph (CPDAG) for the given DAG.
Graph trueGraphCPDAG = GraphTransforms.dagToCpdag(trueGraph);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph CPDAG: " + trueGraphCPDAG);

SemPm pm = new SemPm(trueGraph);

Parameters params = new Parameters();
// Manually set non-Gaussian
params.set(Params.SIMULATION_ERROR_TYPE, 3);
params.set(Params.SIMULATION_PARAM1, 1);

SemIm im = new SemIm(pm, params);
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
score.setPenaltyDiscount(2);
Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());

// Compare the Est CPDAG with True graph's CPDAG.
for(Node a: accepts) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraphCPDAG);
System.out.println("=====================");

}
for (Node a: rejects) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph(a, estimatedCpdag, trueGraphCPDAG);
System.out.println("=====================");
}
}

}

0 comments on commit ea4507d

Please sign in to comment.