Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing LocalGraphConfusion and its corresponding Precision and Recall classes #1771

Merged
merged 2 commits into from
May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package edu.cmu.tetrad.algcomparison.statistic;

import edu.cmu.tetrad.algcomparison.statistic.utils.LocalGraphConfusion;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.graph.Graph;

public class LocalGraphPrecision implements Statistic {
@Override
public String getAbbreviation() {
return "LGP";
}

@Override
public String getDescription() {
return "Local Graph Precision";
}

@Override
public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
LocalGraphConfusion lgConfusion = new LocalGraphConfusion(trueGraph, estGraph);
int lgTp = lgConfusion.getTp();
int lgFp = lgConfusion.getFp();
return lgTp / (double) (lgTp + lgFp);
}

@Override
public double getNormValue(double value) {
return value;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package edu.cmu.tetrad.algcomparison.statistic;

import edu.cmu.tetrad.algcomparison.statistic.utils.LocalGraphConfusion;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.graph.Graph;

public class LocalGraphRecall implements Statistic {
@Override
public String getAbbreviation() {
return "LGR";
}

@Override
public String getDescription() {
return "Local Graph Recall";
}

@Override
public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) {
LocalGraphConfusion lgConfusion = new LocalGraphConfusion(trueGraph, estGraph);
int lgTp = lgConfusion.getTp();
int lgFn = lgConfusion.getFn();
return lgTp / (double) (lgTp + lgFn);
}

@Override
public double getNormValue(double value) {
return value;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package edu.cmu.tetrad.algcomparison.statistic.utils;

import edu.cmu.tetrad.graph.*;

import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
* A confusion matrix for local graph accuracy check --i.e. TP, FP, TN, FN for counts of a combination of
* arrowhead and precision.
*/
public class LocalGraphConfusion {
/**
* The true positive count.
*/
private int tp;

/**
* The true negative count.
*/
private int tn;

/**
* The false positive count.
*/
private int fp;

/**
* The false positive count.
*/
private int fn;

/**
* Constructs a new LocalGraphConfusion object from the given graphs.
* @param trueGraph The true graph
*
* @param estGraph The estimated graph
*/
public LocalGraphConfusion(Graph trueGraph, Graph estGraph) {
this.tp = 0;
this.tn = 0;
this.fp = 0;
this.fn = 0;

// STEP0: Create lookups for both true graph and estimated graph.
// trueGraphLookup is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes.
Graph trueGraphLookup = GraphUtils.replaceNodes(trueGraph, estGraph.getNodes());
// estGraphLookup is the same structure as estGraph's structure but node objects replaced by true graph nodes.
Graph estGraphLookup = GraphUtils.replaceNodes(estGraph, trueGraph.getNodes());

// STEP1: Check for Adjacency.
/**
* True
* Y N
* ---------------------
* Y | TP FP
* Est | --------------------
* N | FN TN
* -----------------------
*/
// STEP 1.1: Create allUnoriented base on trueGraphLookup and estimatedGraph
Set<Edge> allUnoriented = new HashSet<>();
for (Edge edge: trueGraphLookup.getEdges()) {
allUnoriented.add(Edges.undirectedEdge(edge.getNode1(), edge.getNode2()));
}
for (Edge edge: estGraph.getEdges()) {
allUnoriented.add(Edges.undirectedEdge(edge.getNode1(), edge.getNode2()));
}
// STEP 1.2: Iterate through allUnoriented to record confusion metrix
for (Edge u: allUnoriented) {
Node node1 = u.getNode1();
Node node2 = u.getNode2();
if (estGraph.isAdjacentTo(node1, node2)) { // Est: Y
if (trueGraphLookup.isAdjacentTo(node1, node2)) { // True: Y
this.tp++;
} else { // True: N
this.fp++;
}
} else { // Est: N
if (trueGraphLookup.isAdjacentTo(node1, node2)) { // True: Y
this.fn++;
} else { // True: N
this.tn++;
}
}
}

// STEP2: Check for Orientation(i.e. Arrowhead), so we need to check both endpoints of an edge.
/**
* True
* -> <- ...(None)
* ---------------------------
* -> | TP FP,FN / (Do not repeat count, as we checked for it in Adj step)
* Est | --------------------------
* <- | FP, FN TP /
* | --------------------------
* -- | FN FN /
* | --------------------------
* ...| / / /
* -----------------------------
*
*/
// STEP2.1: Check through the true graph
for (Edge tle: trueGraphLookup.getEdges()) {
// STEP2.1.1: Get corresponding endpoint in Est graph lookup
List<Edge> estGraphLookupEdges = estGraphLookup.getEdges(tle.getNode1(), tle.getNode2());
Edge ele; // estimated lookup graph edge
if (estGraphLookupEdges.size() == 1) {
ele = estGraphLookupEdges.iterator().next();
} else {
ele = estGraphLookup.getDirectedEdge(tle.getNode1(), tle.getNode2());
}
Endpoint ep1Est = null;
Endpoint ep2Est = null;
if (ele != null) {
ep1Est = ele.getProximalEndpoint(tle.getNode1());
ep2Est = ele.getProximalEndpoint(tle.getNode2());
}

// STEP2.1.2: Get corresponding endpoint in true graph lookup
List<Edge> trueGraphLookupEdges = trueGraphLookup.getEdges(tle.getNode1(), tle.getNode1());
Edge tle2;
if (trueGraphLookupEdges.size() == 1) {
tle2 = trueGraphLookupEdges.iterator().next();
} else {
tle2 = trueGraphLookup.getDirectedEdge(tle.getNode1(), tle.getNode2());
}
Endpoint ep1True = null;
Endpoint ep2True = null;
if (tle2 != null) {
ep1True = tle2.getProximalEndpoint(tle.getNode1());
ep2True = tle2.getProximalEndpoint(tle.getNode2());
}

// STEP2.1.3: Compare the endpoints
// we only care the case when the edge exist.
boolean connected = trueGraph.isAdjacentTo(tle.getNode1(), tle.getNode2())
&& estGraph.isAdjacentTo(tle.getNode1(), tle.getNode2());
if (connected) {
if (ep1True == Endpoint.TAIL && ep2True == Endpoint.ARROW) { // True: ->
if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.ARROW) { // Est: ->
this.tp++;
} else if (ep1Est == Endpoint.ARROW && ep2Est == Endpoint.TAIL) { // Est: <-
// this.fp++;
this.fn++;
} else if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.TAIL) { // Est: --
this.fn++;
}
} else if (ep1True == Endpoint.ARROW && ep2True == Endpoint.TAIL) { // True: <-
if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.ARROW) { // Est: ->
// this.fp++;
this.fn++;
} else if (ep1Est == Endpoint.ARROW && ep2Est == Endpoint.TAIL) { // Est: <-
this.tp++;
} else if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.TAIL) { // Est: --
this.fn++;
}
}
}
}
// STEP2: Check through the est graph
// because est graph can have extra arrowhead that was not in true graph, which should be count as fp.
for (Edge ele: estGraphLookup.getEdges()) {
List<Edge> estGraphLookupEdges = estGraphLookup.getEdges(ele.getNode1(), ele.getNode2());
Edge ele2;
if (estGraphLookupEdges.size() == 1) {
ele2 = estGraphLookupEdges.iterator().next();
} else {
ele2 = estGraphLookup.getDirectedEdge(ele.getNode1(), ele.getNode2());
}
Endpoint ep1Est = null;
Endpoint ep2Est = null;
if (ele2 != null) {
ep1Est = ele2.getProximalEndpoint(ele.getNode1());
ep2Est = ele2.getProximalEndpoint(ele.getNode2());
}

List<Edge> trueGraphLookupEdges = trueGraphLookup.getEdges(ele.getNode1(), ele.getNode1());
Edge tle;
if (trueGraphLookupEdges.size() == 1) {
tle = trueGraphLookupEdges.iterator().next();
} else {
tle = trueGraphLookup.getDirectedEdge(ele.getNode1(), ele.getNode2());
}
Endpoint ep1True = null;
Endpoint ep2True = null;
if (tle != null) {
ep1True = tle.getProximalEndpoint(ele.getNode1());
ep2True = tle.getProximalEndpoint(ele.getNode2());
}

boolean connected = trueGraph.isAdjacentTo(ele.getNode1(), ele.getNode2());
if (connected) {
if (ep1True == Endpoint.TAIL && ep2True == Endpoint.ARROW) { // True: ->
if (ep1Est == Endpoint.ARROW && ep2Est == Endpoint.TAIL) { // Est: <-
this.fp++;
}
// TODO VBC: Question: seems we wont encounter <-> case, is it?
} else if (ep1True == Endpoint.ARROW && ep2True == Endpoint.TAIL) { // True: <-
if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.ARROW) { // Est: ->
this.fp++;
}
}
}
}
}

public int getTp() {
return tp;
}

public int getTn() {
return tn;
}

public int getFp() {
return fp;
}

public int getFn() {
return fn;
}
}
30 changes: 26 additions & 4 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.algcomparison.statistic.AdjacencyPrecision;
import edu.cmu.tetrad.algcomparison.statistic.AdjacencyRecall;
import edu.cmu.tetrad.algcomparison.statistic.ArrowheadPrecision;
import edu.cmu.tetrad.algcomparison.statistic.ArrowheadRecall;
import edu.cmu.tetrad.algcomparison.statistic.*;
import edu.cmu.tetrad.data.GeneralAndersonDarlingTest;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Graph;
Expand Down Expand Up @@ -332,6 +329,31 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph(Node x, Graph estimatedGra
" ArrowHeadPrecision = " + nf.format(ahp) + " ArrowHeadRecall = " + nf.format(ahr));
}

/**
* Calculates the precision and recall using LocalGraphConfusion
* (which calculates the combination of Adjacency and ArrowHead) on the Markov Blanket graph for a given node.
* Prints the statistics to the console.
*
* @param x The target node.
* @param estimatedGraph The estimated graph.
* @param trueGraph The true graph.
*/
public void getPrecisionAndRecallOnMarkovBlanketGraph2(Node x, Graph estimatedGraph, Graph trueGraph) {
// Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes.
Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes());
Graph xMBLookupGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(lookupGraph, x);
System.out.println("xMBLookupGraph:" + xMBLookupGraph);
Graph xMBEstimatedGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(estimatedGraph, x);
System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph);

double lgp = new LocalGraphPrecision().getValue(xMBLookupGraph, xMBEstimatedGraph, null);
double lgr = new LocalGraphRecall().getValue(xMBLookupGraph, xMBEstimatedGraph, null);

NumberFormat nf = new DecimalFormat("0.00");
System.out.println("Node " + x + "'s statistics: " + " \n" +
" LocalGraphPrecision = " + nf.format(lgp) + " LocalGraphRecall = " + nf.format(lgr) + " \n");
}

/**
* Returns the variables of the independence test.
*
Expand Down
38 changes: 38 additions & 0 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java
Original file line number Diff line number Diff line change
Expand Up @@ -446,4 +446,42 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() {
}
}

@Test
public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() {
Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false);
System.out.println("Test True Graph: " + trueGraph);
System.out.println("Test True Graph size: " + trueGraph.getNodes().size());

SemPm pm = new SemPm(trueGraph);
SemIm im = new SemIm(pm, new Parameters());
DataSet data = im.simulateData(1000, false);
edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false);
score.setPenaltyDiscount(2);
Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search();
System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag);
System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~");

IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05);
MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET);
List<List<Node>> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05);
List<Node> accepts = accepts_rejects.get(0);
List<Node> rejects = accepts_rejects.get(1);
System.out.println("Accepts size: " + accepts.size());
System.out.println("Rejects size: " + rejects.size());

List<Double> acceptsPrecision = new ArrayList<>();
List<Double> acceptsRecall = new ArrayList<>();
for(Node a: accepts) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph);
System.out.println("=====================");

}
for (Node a: rejects) {
System.out.println("=====================");
markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph);
System.out.println("=====================");
}
}

}