Skip to content

Commit

Permalink
Merge pull request #164 from cmu-phil/autistics
Browse files Browse the repository at this point in the history
autistics
  • Loading branch information
jdramsey committed Mar 18, 2016
2 parents b1b0f02 + fd183f7 commit 55d000e
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,7 @@ public void printStuffForKlea() {

out1.close();
out2.close();
}
catch (Exception e) {
} catch (Exception e) {

}

Expand Down Expand Up @@ -738,7 +737,10 @@ public void testFgs(int numVars, double edgesPerNode, int numCases, double penal

public void testFgsDiscrete(int numVars, double edgeFactor, int numCases,
double structurePrior, double samplePrior) {
init(new File("long.FGSDiscrete." + numVars + ".txt"), "Tests performance of the FGS algorithm");
init(new File("long.FGSDiscrete:" + numVars + ":" +
(int) (numVars * edgeFactor) + ":" +
structurePrior + ":" +
samplePrior + ".txt"), "Tests performance of the FGS algorithm");

long time1 = System.currentTimeMillis();

Expand Down Expand Up @@ -911,7 +913,8 @@ public void testSaveLoadDataSerial(int numVars, int numCases) {
vars.add(new ContinuousVariable("X" + i));
}

Graph graph = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numVars, 30, 15, 15, false);;
Graph graph = GraphUtils.randomGraphRandomForwardEdges(vars, 0, numVars, 30, 15, 15, false);
;

out.println("Graph generated.");
// Graph graph = new EndpointMatrixGraph(DataGraphUtils.randomDagQuick(vars, 0, numVars));
Expand Down Expand Up @@ -954,6 +957,141 @@ public void testSaveLoadDataSerial(int numVars, int numCases) {
}
}

public void testFgsComparison(int numVars, double edgesPerNode, int numCases, int numRuns) {
double penaltyDiscount = 4.0;
int depth = 3;

init(new File("fgs.comparison" + numVars + "." + (int) (edgesPerNode * numVars) +
"." + numCases + "." + numRuns + ".txt"), "Num runs = " + numRuns);
out.println("Num vars = " + numVars);
out.println("Num edges = " + (int) (numVars * edgesPerNode));
out.println("Num cases = " + numCases);
out.println("Penalty discount = " + penaltyDiscount);
out.println("Depth = " + depth);
out.println();

List<int[][]> counts = new ArrayList<int[][]>();
List<double[]> arrowStats = new ArrayList<>();
List<double[]> tailStats = new ArrayList<>();
List<Double> degrees = new ArrayList<>();
List<Long> elapsedTimes = new ArrayList<Long>();

for (int run = 0; run < numRuns; run++) {

out.println("\n\n\n******************************** RUN " + (run + 1) + " ********************************\n\n");

System.out.println("Making dag");

Graph dag = makeDag(numVars, edgesPerNode);

Graph pattern = SearchGraphUtils.patternForDag(dag);

List<Node> vars = dag.getNodes();

int[] causalOrdering = new int[vars.size()];

for (int i = 0; i < vars.size(); i++) {
causalOrdering[i] = i;
}

System.out.println("Graph done");

long time1 = System.currentTimeMillis();

out.println("Graph done");

System.out.println("Starting simulation");

LargeSemSimulator simulator = new LargeSemSimulator(dag, vars, causalOrdering);
simulator.setOut(out);

DataSet data = simulator.simulateDataAcyclic(numCases);

System.out.println("Finishing simulation");

long time2 = System.currentTimeMillis();

out.println("Elapsed (simulating the data): " + (time2 - time1) + " ms");

System.out.println("Making covariance matrix");

ICovarianceMatrix cov = new CovarianceMatrixOnTheFly(data);

System.out.println("Covariance matrix done");


long time3 = System.currentTimeMillis();

out.println("Elapsed (calculating cov): " + (time3 - time2) + " ms\n");

SemBicScore score = new SemBicScore(cov, penaltyDiscount);

Fgs fgs = new Fgs(score);
fgs.setVerbose(false);
fgs.setNumPatternsToStore(0);
fgs.setPenaltyDiscount(penaltyDiscount);
fgs.setOut(System.out);
fgs.setFaithfulnessAssumed(true);
fgs.setDepth(-1);
fgs.setCycleBound(5);

System.out.println("\nStarting FGS");

Graph estPattern = fgs.search();

System.out.println("Done with FGS");

long time4 = System.currentTimeMillis();

out.println(new Date());

System.out.println("Making list of vars");

estPattern = GraphUtils.replaceNodes(estPattern, dag.getNodes());

double degree = GraphUtils.degree(estPattern);
degrees.add(degree);

System.out.println("Degree out output graph = " + degree);

arrowStats.add(printCorrectArrows(dag, estPattern, pattern));
tailStats.add(printCorrectTails(dag, estPattern, pattern));

counts.add(SearchGraphUtils.graphComparison(estPattern, pattern, out));

long elapsed = time4 - time3;
elapsedTimes.add(elapsed);
out.println("\nElapsed: " + elapsed + " ms");

directedComparison(dag, pattern, estPattern);

try {
PrintStream out2 = new PrintStream(new File("dag." + run + ".txt"));
out2.println(dag);

PrintStream out3 = new PrintStream(new File("estpag." + run + ".txt"));
out3.println(estPattern);

PrintStream out4 = new PrintStream(new File("truepag." + run + ".txt"));
out4.println(pattern);

out2.close();
out3.close();
out4.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
throw new RuntimeException(e);
}
}

printAverageConfusion("Average", counts);
printAverageStatistics("Average", arrowStats, tailStats, elapsedTimes, degrees);

out.close();

}


public void testGFciComparison(int numVars, double edgesPerNode, int numCases, int numLatents) {
numVars = 1000;
edgesPerNode = 1.0;
Expand Down Expand Up @@ -1097,7 +1235,7 @@ public void testGFciComparison(int numVars, double edgesPerNode, int numCases, i
}

printAverageConfusion("Average", ffciCounts);
printAverageStatistics("Average", ffciArrowStats, ffciTailStats, ffciElapsedTimes);
printAverageStatistics("Average", ffciArrowStats, ffciTailStats, ffciElapsedTimes, new ArrayList<Double>());

out.close();

Expand Down Expand Up @@ -1686,7 +1824,7 @@ private void bidirectedComparison(Graph dag, Graph truePag, Graph estGraph, Set<
}

private void printAverageStatistics(String name, List<double[]> arrowStats, List<double[]> tailStats,
List<Long> elapsedTimes) {
List<Long> elapsedTimes, List<Double> degrees) {
NumberFormat nf =
new DecimalFormat("0");
NumberFormat nf2 = new DecimalFormat("0.00");
Expand Down Expand Up @@ -1726,13 +1864,22 @@ private void printAverageStatistics(String name, List<double[]> arrowStats, List
avgTailStats[i] = sum / (double) tailStats.size();
}

double sumDegrees = 0;

for (int i = 0; i < degrees.size(); i++) {
sumDegrees += degrees.get(i);
}

double avgDegree = sumDegrees / degrees.size();

out.println();
out.println("Avg Correct Tails = " + nf.format(avgTailStats[0]));
out.println("Avg Estimated Tails = " + nf.format(avgTailStats[1]));
out.println("Avg True Tails = " + nf.format(avgTailStats[2]));
out.println("Avg Tail Precision = " + nf2.format(avgTailStats[3]));
out.println("Avg Tail Recall = " + nf2.format(avgTailStats[4]));
out.println("Avg Proportion Correct Ancestor Relationships = " + nf2.format(avgTailStats[5]));
out.println("Avg Max Degree of Output Pattern = " + nf2.format(avgDegree));

double sumElapsed = 0;

Expand Down Expand Up @@ -2057,6 +2204,15 @@ public static void main(String... args) {
performanceTests.testFgs(numVars, edgeFactor, numCases, penaltyDiscount);
break;
}
case "TestFgsComparison": {
final int numVars = Integer.parseInt(args[1]);
final double edgeFactor = Double.parseDouble(args[2]);
final int numCases = Integer.parseInt(args[3]);
final int numRuns = Integer.parseInt(args[4]);

performanceTests.testFgsComparison(numVars, edgeFactor, numCases, numRuns);
break;
}
default:
throw new IllegalArgumentException("Not a configuration: ");
}
Expand Down
26 changes: 13 additions & 13 deletions tetrad-lib/src/main/java/edu/cmu/tetrad/search/BDeuScore.java
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public double localScore(int node, int parents[]) {
// if (_score != null) return _score;

// Number of categories for node.
int r = numCategories[node];
int c = numCategories[node];

// Numbers of categories of parents.
int[] dims = new int[parents.length];
Expand All @@ -108,15 +108,15 @@ public double localScore(int node, int parents[]) {
}

// Number of parent states.
int q = 1;
int r = 1;

for (int p = 0; p < parents.length; p++) {
q *= dims[p];
r *= dims[p];
}

// Conditional cell coefs of data for node given parents(node).
int n_jk[][] = new int[q][r];
int n_j[] = new int[q];
int n_jk[][] = new int[r][c];
int n_j[] = new int[r];

int[] parentValues = new int[parents.length];

Expand Down Expand Up @@ -149,22 +149,22 @@ public double localScore(int node, int parents[]) {
//Finally, compute the score
double score = 0.0;

// score += r * q * FastMath.log(getStructurePrior());
// score += c * r * FastMath.log(getStructurePrior());
score += getPriorForStructure(parents.length);

final double cellPrior = getSamplePrior() / (r * q);
final double rowPrior = getSamplePrior() / q;
final double cellPrior = getSamplePrior() / (c * r);
final double rowPrior = getSamplePrior() / r;

for (int j = 0; j < q; j++) {
for (int j = 0; j < r; j++) {
score -= Gamma.logGamma(rowPrior + n_j[j]);

for (int k = 0; k < r; k++) {
for (int k = 0; k < c; k++) {
score += Gamma.logGamma(cellPrior + n_jk[j][k]);
}
}

score += q * Gamma.logGamma(rowPrior);
score -= r * q * Gamma.logGamma(cellPrior);
score += r * Gamma.logGamma(rowPrior);
score -= c * r * Gamma.logGamma(cellPrior);

// if (parents.length <= 1) {
// all.put(asList(node, parents), score);
Expand All @@ -185,7 +185,7 @@ private double getPriorForStructure(int numParents) {
double e = getStructurePrior();
int k = numParents;
int vm = data.length - 1;
return Math.log(e / (vm)) + (vm - k) * Math.log(1.0 - (e / (vm)));
return k * Math.log(e / (vm)) + (vm - k) * Math.log(1.0 - (e / (vm)));
}

private double getPriorForStructure2(int numParents) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,17 @@ public double localScoreDiff(int x, int y, int[] z) {
*/
public double localScore(int i, int[] parents) {
double sum = 0.0;
int count = 0;

for (SemBicScore score : semBicScores) {
sum += score.localScore(i, parents);
double _score = score.localScore(i, parents);

if (!Double.isNaN(_score)) {
sum += _score;
}
}

return sum / semBicScores.size();
return sum / count;
}

public double localScore(int i, int[] parents, int index) {
Expand Down
4 changes: 2 additions & 2 deletions tetrad-lib/src/test/java/edu/cmu/tetrad/test/TestFgs.java
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ public void explore2() {
{0, 0, 0, 0, 0, 0},
{0, 0, 0, 0, 0, 0},
{0, 0, 0, 0, 0, 0},
{0, 0, 0, 15, 0, 1},
{0, 0, 0, 1, 0, 0},
{1, 0, 0, 13, 0, 3},
{0, 0, 0, 0, 0, 0},
{0, 0, 0, 0, 0, 0},
{0, 0, 0, 0, 0, 0},
};
Expand Down

0 comments on commit 55d000e

Please sign in to comment.