Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce plot data collection for different local graph confusion statistics #1776

Merged
merged 4 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 141 additions & 11 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,17 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(Ind
return accepts_rejects;
}

/**
* Get accepts and rejects nodes for all nodes from Anderson-Darling test and generate the plot data for confusion statistics.
*
* Confusion statistics were calculated using Adjacency (AdjacencyPrecision, AdjacencyRecall) and Arrowhead (ArrowheadPrecision, ArrowheadRecall)
* @param independenceTest
* @param estimatedCpdag
* @param trueGraph
* @param threshold
* @param shuffleThreshold
* @return
*/
public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) {
// When calling, default reject null as <=0.05
List<List<Node>> accepts_rejects = new ArrayList<>();
Expand Down Expand Up @@ -381,35 +392,35 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5
for (List<Double> localPValues: shuffledlocalPValues) {
// P value obtained from AD test
Double ADTest = checkAgainstAndersonDarlingTest(localPValues);
Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues);
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
if (ADTest <= threshold) {
if (ADTestPValue <= threshold) {
rejects.add(x);
if (!Double.isNaN(ap)) {
rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTest));
rejects_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ar)) {
rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTest));
rejects_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahp)) {
rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTest));
rejects_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahr)) {
rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTest));
rejects_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
} else {
accepts.add(x);
if (!Double.isNaN(ap)) {
accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTest));
accepts_AdjP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ar)) {
accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTest));
accepts_AdjR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahp)) {
accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTest));
accepts_AHP_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
if (!Double.isNaN(ahr)) {
accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTest));
accepts_AHR_ADTestP.add(Arrays.asList(ap, ADTestPValue));
}
}
}
Expand All @@ -421,7 +432,7 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
try (BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()))) {
writer.write(entry.getValue());
switch (entry.getKey()) {
case "acceptsAdjP_ADTestP_data.csv":
case "accepts_AdjP_ADTestP_data.csv":
for (List<Double> AdjP_ADTestP_pair : accepts_AdjP_ADTestP) {
writer.write(nf.format(AdjP_ADTestP_pair.get(0)) + "," + nf.format(AdjP_ADTestP_pair.get(1)) + "\n");
}
Expand Down Expand Up @@ -479,6 +490,112 @@ public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlot
return accepts_rejects;
}

/**
* Get accepts and rejects nodes for all nodes from Anderson-Darling test and generate the plot data for confusion statistics.
*
* Confusion statistics were calculated using Local Graph Precision and Recall (LocalGraphPrecision, LocalGraphRecall).
* @param independenceTest
* @param estimatedCpdag
* @param trueGraph
* @param threshold
* @param shuffleThreshold
* @return
*/
public List<List<Node>> getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(IndependenceTest independenceTest, Graph estimatedCpdag, Graph trueGraph, Double threshold, Double shuffleThreshold) {
// When calling, default reject null as <=0.05
List<List<Node>> accepts_rejects = new ArrayList<>();
List<Node> accepts = new ArrayList<>();
List<Node> rejects = new ArrayList<>();
List<Node> allNodes = graph.getNodes();

// Confusion stats lists for data processing.
Map<String, String> fileContentMap = new HashMap<>();

// Using Local Graph Precision and Recall to calculate Confusion statistics.
List<List<Double>> accepts_LGP_ADTestP = new ArrayList<>();
List<List<Double>> accepts_LGR_ADTestP = new ArrayList<>();
fileContentMap.put("accepts_LGP_ADTestP_data.csv", "");
fileContentMap.put("accepts_LGR_ADTestP_data.csv", "");

List<List<Double>> rejects_LGP_ADTestP = new ArrayList<>();
List<List<Double>> rejects_LGR_ADTestP = new ArrayList<>();
fileContentMap.put("rejects_LGP_ADTestP_data.csv", "");
fileContentMap.put("rejects_LGR_ADTestP_data.csv", "");

NumberFormat nf = new DecimalFormat("0.00");
// Classify nodes into accepts and rejects base on ADTest result, and update confusion stats lists accordingly.
for (Node x : allNodes) {
List<IndependenceFact> localIndependenceFacts = getLocalIndependenceFacts(x);
List<Double> lgp_lgr = getPrecisionAndRecallOnMarkovBlanketGraphPlotData2(x, estimatedCpdag, trueGraph);
Double lgp = lgp_lgr.get(0);
Double lgr = lgp_lgr.get(1);
// All local nodes' p-values for node x.
List<List<Double>> shuffledlocalPValues = getLocalPValues(independenceTest, localIndependenceFacts, shuffleThreshold); // shuffleThreshold default to be 0.5
for (List<Double> localPValues: shuffledlocalPValues) {
// P value obtained from AD test
Double ADTestPValue = checkAgainstAndersonDarlingTest(localPValues);
// TODO VBC: what should we do for cases when ADTest is NaN and ∞ ?
if (ADTestPValue <= threshold) {
rejects.add(x);
if (!Double.isNaN(lgp)) {
rejects_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue));
}
if (!Double.isNaN(lgr)) {
rejects_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue));
}
} else {
accepts.add(x);
if (!Double.isNaN(lgp)) {
accepts_LGP_ADTestP.add(Arrays.asList(lgp, ADTestPValue));
}
if (!Double.isNaN(lgr)) {
accepts_LGR_ADTestP.add(Arrays.asList(lgr, ADTestPValue));
}
}
}
}
accepts_rejects.add(accepts);
accepts_rejects.add(rejects);
// Write into data files.
for (Map.Entry<String, String> entry : fileContentMap.entrySet()) {
try (BufferedWriter writer = new BufferedWriter(new FileWriter(entry.getKey()))) {
writer.write(entry.getValue());
switch (entry.getKey()) {
case "accepts_LGP_ADTestP_data.csv":
for (List<Double> LGP_ADTestP_pair : accepts_LGP_ADTestP) {
writer.write(nf.format(LGP_ADTestP_pair.get(0)) + "," + nf.format(LGP_ADTestP_pair.get(1)) + "\n");
}
break;

case "accepts_LGR_ADTestP_data.csv":
for (List<Double> LGR_ADTestP_pair : accepts_LGR_ADTestP) {
writer.write(nf.format(LGR_ADTestP_pair.get(0)) + "," + nf.format(LGR_ADTestP_pair.get(1)) + "\n");
}
break;

case "rejects_LGP_ADTestP_data.csv":
for (List<Double> LGP_ADTestP_pair : rejects_LGP_ADTestP) {
writer.write(nf.format(LGP_ADTestP_pair.get(0)) + "," + nf.format(LGP_ADTestP_pair.get(1)) + "\n");
}
break;

case "rejects_LGR_ADTestP_data.csv":
for (List<Double> LGR_ADTestP_pair : rejects_LGR_ADTestP) {
writer.write(nf.format(LGR_ADTestP_pair.get(0)) + "," + nf.format(LGR_ADTestP_pair.get(1)) + "\n");
}
break;

default:
break;
}
System.out.println("Successfully written to " + entry.getKey());
} catch (IOException e) {
e.printStackTrace();
}
}
return accepts_rejects;
}

/**
* Calculates the precision and recall on the Markov Blanket graph for a given node. Prints the statistics to the
* console.
Expand Down Expand Up @@ -547,6 +664,19 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph2(Node x, Graph estimatedGr
" LocalGraphPrecision = " + nf.format(lgp) + " LocalGraphRecall = " + nf.format(lgr) + " \n");
}

public List<Double> getPrecisionAndRecallOnMarkovBlanketGraphPlotData2(Node x, Graph estimatedGraph, Graph trueGraph) {
// Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes.
Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes());
Graph xMBLookupGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(lookupGraph, x);
System.out.println("xMBLookupGraph:" + xMBLookupGraph);
Graph xMBEstimatedGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(estimatedGraph, x);
System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph);

double lgp = new LocalGraphPrecision().getValue(xMBLookupGraph, xMBEstimatedGraph, null);
double lgr = new LocalGraphRecall().getValue(xMBLookupGraph, xMBEstimatedGraph, null);
return Arrays.asList(lgp, lgr);
}

/**
* Returns the variables of the independence test.
*
Expand Down
20 changes: 4 additions & 16 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() {
}

@Test
public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() {
public void testGaussianDAGPrecisionRecallForLocalOnMarkovBlanket2() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());
Expand All @@ -461,25 +461,13 @@ public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() {
IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
// ADTest pass/fail threshold default to be 0.05. shuffleThreshold default to be 0.5
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5);
// List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05, 0.5);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodesPlotData2(fisherZTest, estimatedCpdag, trueGraph, 0.05, 0.3);

List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());

List<Double> acceptsPrecision = new ArrayList<>();
List<Double> acceptsRecall = new ArrayList<>();
for(Node a: accepts) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph);
System.out.println("=====================");

}
for (Node a: rejects) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph);
System.out.println("=====================");
}
}

}