diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LocalGraphPrecision.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LocalGraphPrecision.java new file mode 100644 index 0000000000..cdc2c3b57c --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LocalGraphPrecision.java @@ -0,0 +1,30 @@ +package edu.cmu.tetrad.algcomparison.statistic; + +import edu.cmu.tetrad.algcomparison.statistic.utils.LocalGraphConfusion; +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.graph.Graph; + +public class LocalGraphPrecision implements Statistic { + @Override + public String getAbbreviation() { + return "LGP"; + } + + @Override + public String getDescription() { + return "Local Graph Precision"; + } + + @Override + public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { + LocalGraphConfusion lgConfusion = new LocalGraphConfusion(trueGraph, estGraph); + int lgTp = lgConfusion.getTp(); + int lgFp = lgConfusion.getFp(); + return lgTp / (double) (lgTp + lgFp); + } + + @Override + public double getNormValue(double value) { + return value; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LocalGraphRecall.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LocalGraphRecall.java new file mode 100644 index 0000000000..94b893d248 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/LocalGraphRecall.java @@ -0,0 +1,30 @@ +package edu.cmu.tetrad.algcomparison.statistic; + +import edu.cmu.tetrad.algcomparison.statistic.utils.LocalGraphConfusion; +import edu.cmu.tetrad.data.DataModel; +import edu.cmu.tetrad.graph.Graph; + +public class LocalGraphRecall implements Statistic { + @Override + public String getAbbreviation() { + return "LGR"; + } + + @Override + public String getDescription() { + return "Local Graph Recall"; + } + + @Override + public double getValue(Graph trueGraph, Graph estGraph, DataModel dataModel) { + LocalGraphConfusion lgConfusion = new LocalGraphConfusion(trueGraph, estGraph); + int lgTp = lgConfusion.getTp(); + int lgFn = lgConfusion.getFn(); + return lgTp / (double) (lgTp + lgFn); + } + + @Override + public double getNormValue(double value) { + return value; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/LocalGraphConfusion.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/LocalGraphConfusion.java new file mode 100644 index 0000000000..c10e7c2da4 --- /dev/null +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/algcomparison/statistic/utils/LocalGraphConfusion.java @@ -0,0 +1,224 @@ +package edu.cmu.tetrad.algcomparison.statistic.utils; + +import edu.cmu.tetrad.graph.*; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * A confusion matrix for local graph accuracy check --i.e. TP, FP, TN, FN for counts of a combination of + * arrowhead and precision. + */ +public class LocalGraphConfusion { + /** + * The true positive count. + */ + private int tp; + + /** + * The true negative count. + */ + private int tn; + + /** + * The false positive count. + */ + private int fp; + + /** + * The false positive count. + */ + private int fn; + + /** + * Constructs a new LocalGraphConfusion object from the given graphs. + * @param trueGraph The true graph + * + * @param estGraph The estimated graph + */ + public LocalGraphConfusion(Graph trueGraph, Graph estGraph) { + this.tp = 0; + this.tn = 0; + this.fp = 0; + this.fn = 0; + + // STEP0: Create lookups for both true graph and estimated graph. + // trueGraphLookup is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes. + Graph trueGraphLookup = GraphUtils.replaceNodes(trueGraph, estGraph.getNodes()); + // estGraphLookup is the same structure as estGraph's structure but node objects replaced by true graph nodes. + Graph estGraphLookup = GraphUtils.replaceNodes(estGraph, trueGraph.getNodes()); + + // STEP1: Check for Adjacency. + /** + * True + * Y N + * --------------------- + * Y | TP FP + * Est | -------------------- + * N | FN TN + * ----------------------- + */ + // STEP 1.1: Create allUnoriented base on trueGraphLookup and estimatedGraph + Set allUnoriented = new HashSet<>(); + for (Edge edge: trueGraphLookup.getEdges()) { + allUnoriented.add(Edges.undirectedEdge(edge.getNode1(), edge.getNode2())); + } + for (Edge edge: estGraph.getEdges()) { + allUnoriented.add(Edges.undirectedEdge(edge.getNode1(), edge.getNode2())); + } + // STEP 1.2: Iterate through allUnoriented to record confusion metrix + for (Edge u: allUnoriented) { + Node node1 = u.getNode1(); + Node node2 = u.getNode2(); + if (estGraph.isAdjacentTo(node1, node2)) { // Est: Y + if (trueGraphLookup.isAdjacentTo(node1, node2)) { // True: Y + this.tp++; + } else { // True: N + this.fp++; + } + } else { // Est: N + if (trueGraphLookup.isAdjacentTo(node1, node2)) { // True: Y + this.fn++; + } else { // True: N + this.tn++; + } + } + } + + // STEP2: Check for Orientation(i.e. Arrowhead), so we need to check both endpoints of an edge. + /** + * True + * -> <- ...(None) + * --------------------------- + * -> | TP FP,FN / (Do not repeat count, as we checked for it in Adj step) + * Est | -------------------------- + * <- | FP, FN TP / + * | -------------------------- + * -- | FN FN / + * | -------------------------- + * ...| / / / + * ----------------------------- + * + */ + // STEP2.1: Check through the true graph + for (Edge tle: trueGraphLookup.getEdges()) { + // STEP2.1.1: Get corresponding endpoint in Est graph lookup + List estGraphLookupEdges = estGraphLookup.getEdges(tle.getNode1(), tle.getNode2()); + Edge ele; // estimated lookup graph edge + if (estGraphLookupEdges.size() == 1) { + ele = estGraphLookupEdges.iterator().next(); + } else { + ele = estGraphLookup.getDirectedEdge(tle.getNode1(), tle.getNode2()); + } + Endpoint ep1Est = null; + Endpoint ep2Est = null; + if (ele != null) { + ep1Est = ele.getProximalEndpoint(tle.getNode1()); + ep2Est = ele.getProximalEndpoint(tle.getNode2()); + } + + // STEP2.1.2: Get corresponding endpoint in true graph lookup + List trueGraphLookupEdges = trueGraphLookup.getEdges(tle.getNode1(), tle.getNode1()); + Edge tle2; + if (trueGraphLookupEdges.size() == 1) { + tle2 = trueGraphLookupEdges.iterator().next(); + } else { + tle2 = trueGraphLookup.getDirectedEdge(tle.getNode1(), tle.getNode2()); + } + Endpoint ep1True = null; + Endpoint ep2True = null; + if (tle2 != null) { + ep1True = tle2.getProximalEndpoint(tle.getNode1()); + ep2True = tle2.getProximalEndpoint(tle.getNode2()); + } + + // STEP2.1.3: Compare the endpoints + // we only care the case when the edge exist. + boolean connected = trueGraph.isAdjacentTo(tle.getNode1(), tle.getNode2()) + && estGraph.isAdjacentTo(tle.getNode1(), tle.getNode2()); + if (connected) { + if (ep1True == Endpoint.TAIL && ep2True == Endpoint.ARROW) { // True: -> + if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.ARROW) { // Est: -> + this.tp++; + } else if (ep1Est == Endpoint.ARROW && ep2Est == Endpoint.TAIL) { // Est: <- + // this.fp++; + this.fn++; + } else if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.TAIL) { // Est: -- + this.fn++; + } + } else if (ep1True == Endpoint.ARROW && ep2True == Endpoint.TAIL) { // True: <- + if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.ARROW) { // Est: -> + // this.fp++; + this.fn++; + } else if (ep1Est == Endpoint.ARROW && ep2Est == Endpoint.TAIL) { // Est: <- + this.tp++; + } else if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.TAIL) { // Est: -- + this.fn++; + } + } + } + } + // STEP2: Check through the est graph + // because est graph can have extra arrowhead that was not in true graph, which should be count as fp. + for (Edge ele: estGraphLookup.getEdges()) { + List estGraphLookupEdges = estGraphLookup.getEdges(ele.getNode1(), ele.getNode2()); + Edge ele2; + if (estGraphLookupEdges.size() == 1) { + ele2 = estGraphLookupEdges.iterator().next(); + } else { + ele2 = estGraphLookup.getDirectedEdge(ele.getNode1(), ele.getNode2()); + } + Endpoint ep1Est = null; + Endpoint ep2Est = null; + if (ele2 != null) { + ep1Est = ele2.getProximalEndpoint(ele.getNode1()); + ep2Est = ele2.getProximalEndpoint(ele.getNode2()); + } + + List trueGraphLookupEdges = trueGraphLookup.getEdges(ele.getNode1(), ele.getNode1()); + Edge tle; + if (trueGraphLookupEdges.size() == 1) { + tle = trueGraphLookupEdges.iterator().next(); + } else { + tle = trueGraphLookup.getDirectedEdge(ele.getNode1(), ele.getNode2()); + } + Endpoint ep1True = null; + Endpoint ep2True = null; + if (tle != null) { + ep1True = tle.getProximalEndpoint(ele.getNode1()); + ep2True = tle.getProximalEndpoint(ele.getNode2()); + } + + boolean connected = trueGraph.isAdjacentTo(ele.getNode1(), ele.getNode2()); + if (connected) { + if (ep1True == Endpoint.TAIL && ep2True == Endpoint.ARROW) { // True: -> + if (ep1Est == Endpoint.ARROW && ep2Est == Endpoint.TAIL) { // Est: <- + this.fp++; + } + // TODO VBC: Question: seems we wont encounter <-> case, is it? + } else if (ep1True == Endpoint.ARROW && ep2True == Endpoint.TAIL) { // True: <- + if (ep1Est == Endpoint.TAIL && ep2Est == Endpoint.ARROW) { // Est: -> + this.fp++; + } + } + } + } + } + + public int getTp() { + return tp; + } + + public int getTn() { + return tn; + } + + public int getFp() { + return fp; + } + + public int getFn() { + return fn; + } +} diff --git a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java index f29e0d7cf6..510e043ed2 100644 --- a/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java +++ b/tetrad-lib/src/main/java/edu/cmu/tetrad/search/MarkovCheck.java @@ -1,9 +1,6 @@ package edu.cmu.tetrad.search; -import edu.cmu.tetrad.algcomparison.statistic.AdjacencyPrecision; -import edu.cmu.tetrad.algcomparison.statistic.AdjacencyRecall; -import edu.cmu.tetrad.algcomparison.statistic.ArrowheadPrecision; -import edu.cmu.tetrad.algcomparison.statistic.ArrowheadRecall; +import edu.cmu.tetrad.algcomparison.statistic.*; import edu.cmu.tetrad.data.GeneralAndersonDarlingTest; import edu.cmu.tetrad.data.Knowledge; import edu.cmu.tetrad.graph.Graph; @@ -332,6 +329,31 @@ public void getPrecisionAndRecallOnMarkovBlanketGraph(Node x, Graph estimatedGra " ArrowHeadPrecision = " + nf.format(ahp) + " ArrowHeadRecall = " + nf.format(ahr)); } + /** + * Calculates the precision and recall using LocalGraphConfusion + * (which calculates the combination of Adjacency and ArrowHead) on the Markov Blanket graph for a given node. + * Prints the statistics to the console. + * + * @param x The target node. + * @param estimatedGraph The estimated graph. + * @param trueGraph The true graph. + */ + public void getPrecisionAndRecallOnMarkovBlanketGraph2(Node x, Graph estimatedGraph, Graph trueGraph) { + // Lookup graph is the same structure as trueGraph's structure but node objects replaced by estimated graph nodes. + Graph lookupGraph = GraphUtils.replaceNodes(trueGraph, estimatedGraph.getNodes()); + Graph xMBLookupGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(lookupGraph, x); + System.out.println("xMBLookupGraph:" + xMBLookupGraph); + Graph xMBEstimatedGraph = GraphUtils.getMarkovBlanketSubgraphWithTargetNode(estimatedGraph, x); + System.out.println("xMBEstimatedGraph:" + xMBEstimatedGraph); + + double lgp = new LocalGraphPrecision().getValue(xMBLookupGraph, xMBEstimatedGraph, null); + double lgr = new LocalGraphRecall().getValue(xMBLookupGraph, xMBEstimatedGraph, null); + + NumberFormat nf = new DecimalFormat("0.00"); + System.out.println("Node " + x + "'s statistics: " + " \n" + + " LocalGraphPrecision = " + nf.format(lgp) + " LocalGraphRecall = " + nf.format(lgr) + " \n"); + } + /** * Returns the variables of the independence test. * diff --git a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java index c0c381b54f..27fdf9b703 100644 --- a/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java +++ b/tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestCheckMarkov.java @@ -446,4 +446,42 @@ public void testNonGaussianCPDAGPrecisionRecallForLocalOnParents() { } } + @Test + public void testDAGPrecisionRecall2ForLocalOnMarkovBlanket() { + Graph trueGraph = RandomGraph.randomDag(10, 0, 10, 100, 100, 100, false); + System.out.println("Test True Graph: " + trueGraph); + System.out.println("Test True Graph size: " + trueGraph.getNodes().size()); + + SemPm pm = new SemPm(trueGraph); + SemIm im = new SemIm(pm, new Parameters()); + DataSet data = im.simulateData(1000, false); + edu.cmu.tetrad.search.score.SemBicScore score = new SemBicScore(data, false); + score.setPenaltyDiscount(2); + Graph estimatedCpdag = new PermutationSearch(new Boss(score)).search(); + System.out.println("Test Estimated CPDAG Graph: " + estimatedCpdag); + System.out.println("~~~~~~~~~~~~~~~~~~~~~~~~~~~~"); + + IndependenceTest fisherZTest = new IndTestFisherZ(data, 0.05); + MarkovCheck markovCheck = new MarkovCheck(estimatedCpdag, fisherZTest, ConditioningSetType.MARKOV_BLANKET); + List> accepts_rejects = markovCheck.getAndersonDarlingTestAcceptsRejectsNodesForAllNodes(fisherZTest, estimatedCpdag, 0.05); + List accepts = accepts_rejects.get(0); + List rejects = accepts_rejects.get(1); + System.out.println("Accepts size: " + accepts.size()); + System.out.println("Rejects size: " + rejects.size()); + + List acceptsPrecision = new ArrayList<>(); + List acceptsRecall = new ArrayList<>(); + for(Node a: accepts) { + System.out.println("====================="); + markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); + System.out.println("====================="); + + } + for (Node a: rejects) { + System.out.println("====================="); + markovCheck.getPrecisionAndRecallOnMarkovBlanketGraph2(a, estimatedCpdag, trueGraph); + System.out.println("====================="); + } + } + }