Skip to content

Commit

Permalink
[SYSTEMDS-3714] N-gram statistics of operation sequences
Browse files Browse the repository at this point in the history
Closes #2045.
  • Loading branch information
Jaybit0 authored and mboehm7 committed Jul 24, 2024
1 parent 195f83b commit 9f96718
Show file tree
Hide file tree
Showing 7 changed files with 529 additions and 4 deletions.
27 changes: 27 additions & 0 deletions src/main/java/org/apache/sysds/api/DMLOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ public class DMLOptions {
public String configFile = null; // Path to config file if default config and default config is to be overridden
public boolean clean = false; // Whether to clean up all SystemDS working directories (FS, DFS)
public boolean stats = false; // Whether to record and print the statistics
public boolean statsNGrams = false; // Whether to record and print the statistics n-grams
public int statsCount = 10; // Default statistics count
public int[] statsNGramSizes = { 3 }; // Default n-gram tuple sizes
public int statsTopKNGrams = 10; // How many of the most heavy hitting n-grams are displayed
public boolean fedStats = false; // Whether to record and print the federated statistics
public int fedStatsCount = 10; // Default federated statistics count
public boolean memStats = false; // max memory statistics
Expand Down Expand Up @@ -212,6 +215,26 @@ else if (lineageType.equalsIgnoreCase("debugger"))
}
}
}

dmlOptions.statsNGrams = line.hasOption("ngrams");
if (dmlOptions.statsNGrams){
String[] nGramArgs = line.getOptionValues("ngrams");
if (nGramArgs.length == 2) {
try {
String[] nGramSizeSplit = nGramArgs[0].split(",");
dmlOptions.statsNGramSizes = new int[nGramSizeSplit.length];

for (int i = 0; i < nGramSizeSplit.length; i++) {
dmlOptions.statsNGramSizes[i] = Integer.parseInt(nGramSizeSplit[i]);
}

dmlOptions.statsTopKNGrams = Integer.parseInt(nGramArgs[1]);
} catch (NumberFormatException e) {
throw new org.apache.commons.cli.ParseException("Invalid argument specified for -ngrams option, must be a valid integer");
}
}
}

dmlOptions.fedStats = line.hasOption("fedStats");
if (dmlOptions.fedStats) {
String fedStatsCount = line.getOptionValue("fedStats");
Expand Down Expand Up @@ -335,6 +358,9 @@ private static Options createCLIOptions() {
Option statsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary execution statistics; heavy hitter <count> is 10 unless overridden; default off")
.hasOptionalArg().create("stats");
Option ngramsOpt = OptionBuilder//.withArgName("ngrams")
.withDescription("monitors and reports the most occurring n-grams; -ngrams <comma separated n's> <topK>")
.hasOptionalArgs(2).create("ngrams");
Option fedStatsOpt = OptionBuilder.withArgName("count")
.withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter <count> is 10 unless overridden; default off")
.hasOptionalArg().create("fedStats");
Expand Down Expand Up @@ -396,6 +422,7 @@ private static Options createCLIOptions() {
options.addOption(configOpt);
options.addOption(cleanOpt);
options.addOption(statsOpt);
options.addOption(ngramsOpt);
options.addOption(fedStatsOpt);
options.addOption(memOpt);
options.addOption(explainOpt);
Expand Down
9 changes: 9 additions & 0 deletions src/main/java/org/apache/sysds/api/DMLScript.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,16 @@ public class DMLScript
private static ExecMode EXEC_MODE = DMLOptions.defaultOptions.execMode;
// Enable/disable to print statistics
public static boolean STATISTICS = DMLOptions.defaultOptions.stats;
// Enable/disable to print statistics n-grams
public static boolean STATISTICS_NGRAMS = DMLOptions.defaultOptions.statsNGrams;
// Enable/disable to gather memory use stats in JMLC
public static boolean JMLC_MEM_STATISTICS = false;
// Set maximum heavy hitter count
public static int STATISTICS_COUNT = DMLOptions.defaultOptions.statsCount;
// The sizes of recorded n-gram tuples
public static int[] STATISTICS_NGRAM_SIZES = DMLOptions.defaultOptions.statsNGramSizes;
// Set top k displayed n-grams limit
public static int STATISTICS_TOP_K_NGRAMS = DMLOptions.defaultOptions.statsTopKNGrams;
// Set statistics maximum wrap length
public static int STATISTICS_MAX_WRAP_LEN = 30;
// Enable/disable to print federated statistics
Expand Down Expand Up @@ -250,6 +256,9 @@ public static boolean executeScript( String[] args )
{
STATISTICS = dmlOptions.stats;
STATISTICS_COUNT = dmlOptions.statsCount;
STATISTICS_NGRAMS = dmlOptions.statsNGrams;
STATISTICS_NGRAM_SIZES = dmlOptions.statsNGramSizes;
STATISTICS_TOP_K_NGRAMS = dmlOptions.statsTopKNGrams;
FED_STATISTICS = dmlOptions.fedStats;
FED_STATISTICS_COUNT = dmlOptions.fedStatsCount;
JMLC_MEM_STATISTICS = dmlOptions.memStats;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ protected ScalarObject executePredicateInstructions(ArrayList<Instruction> inst,
private void executeSingleInstruction(Instruction currInst, ExecutionContext ec) {
try {
// start time measurement for statistics
long t0 = (DMLScript.STATISTICS || LOG.isTraceEnabled()) ? System.nanoTime() : 0;
long t0 = (DMLScript.STATISTICS || DMLScript.STATISTICS_NGRAMS || LOG.isTraceEnabled()) ? System.nanoTime() : 0;

// pre-process instruction (inst patching, listeners, lineage)
Instruction tmp = currInst.preprocessInstruction(ec);
Expand All @@ -263,6 +263,10 @@ private void executeSingleInstruction(Instruction currInst, ExecutionContext ec)
if(DMLScript.STATISTICS) {
Statistics.maintainCPHeavyHitters(tmp.getExtendedOpcode(), System.nanoTime() - t0);
}

if (DMLScript.STATISTICS_NGRAMS) {
Statistics.maintainNGrams(tmp.getExtendedOpcode(), System.nanoTime() - t0);
}
}

// optional trace information (instruction and runtime)
Expand Down
231 changes: 228 additions & 3 deletions src/main/java/org/apache/sysds/utils/Statistics.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,21 +34,25 @@
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
import org.apache.sysds.utils.stats.CodegenStatistics;
import org.apache.sysds.utils.stats.RecompileStatistics;
import org.apache.sysds.utils.stats.NGramBuilder;
import org.apache.sysds.utils.stats.NativeStatistics;
import org.apache.sysds.utils.stats.ParamServStatistics;
import org.apache.sysds.utils.stats.ParForStatistics;
import org.apache.sysds.utils.stats.ParamServStatistics;
import org.apache.sysds.utils.stats.RecompileStatistics;
import org.apache.sysds.utils.stats.SparkStatistics;
import org.apache.sysds.utils.stats.TransformStatistics;

import java.lang.management.CompilationMXBean;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -64,14 +68,55 @@ private static class InstStats {
private final LongAdder time = new LongAdder();
private final LongAdder count = new LongAdder();
}

/**
 * Immutable summary of a set of timing samples: sample count, summed time and
 * the sum of squared deviations from the mean. Summaries are combined with
 * {@link #merge(NGramStats, NGramStats)}; a single measurement is represented
 * as {@code new NGramStats(1, timeNanos, 0)}.
 */
public static class NGramStats {

    /** Number of samples aggregated into this summary. */
    public final long n;
    /** Sum of all sample durations in nanoseconds. */
    public final long cumTimeNanos;
    /** Sum of squared deviations from the mean; merge() maintains it in seconds^2. */
    public final double m2;

    /**
     * @param <T> element type of the n-gram entries being compared
     * @return a comparator ordering n-gram entries by the cumulative time of
     *         their cumulative stats, ascending
     */
    public static <T> Comparator<NGramBuilder.NGramEntry<T, NGramStats>> getComparator() {
        return Comparator.comparingLong(entry -> entry.getCumStats().cumTimeNanos);
    }

    /**
     * Combines two summaries into one covering the samples of both.
     *
     * @param stats1 first summary
     * @param stats2 second summary
     * @return a new summary with added counts and times and a combined m2
     */
    public static NGramStats merge(NGramStats stats1, NGramStats stats2) {
        // Using the algorithm from: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
        long newN = stats1.n + stats2.n;
        long cumTimeNanos = stats1.cumTimeNanos + stats2.cumTimeNanos;

        // An operand without samples has no mean: dividing by n == 0 would give
        // NaN and poison m2 for every later merge. It contributes no cross term.
        if (stats1.n == 0 || stats2.n == 0)
            return new NGramStats(newN, cumTimeNanos, stats1.m2 + stats2.m2);

        // Means in seconds; the division by a double keeps this in floating point.
        double mean1 = (double) stats1.cumTimeNanos / 1000000000d / stats1.n;
        double mean2 = (double) stats2.cumTimeNanos / 1000000000d / stats2.n;
        double delta = mean2 - mean1;

        double newM2 = stats1.m2 + stats2.m2 + delta * delta * stats1.n * stats2.n / (double)newN;

        return new NGramStats(newN, cumTimeNanos, newM2);
    }

    /**
     * @param n number of samples
     * @param cumTimeNanos summed duration of the samples in nanoseconds
     * @param m2 sum of squared deviations from the mean
     */
    public NGramStats(final long n, final long cumTimeNanos, final double m2) {
        this.n = n;
        this.cumTimeNanos = cumTimeNanos;
        this.m2 = m2;
    }

    /**
     * @return m2 divided by (n - 1), with the divisor clamped to at least 1 so
     *         that zero or one sample does not divide by zero
     */
    public double getTimeVariance() {
        return m2 / Math.max(n-1, 1);
    }

    /** @return the cumulative time in seconds with five decimal places (US locale) */
    @Override
    public String toString() {
        return String.format(Locale.US, "%.5f", (cumTimeNanos / 1000000000d));
    }
}

private static long compileStartTime = 0;
private static long compileEndTime = 0;
private static long execStartTime = 0;
private static long execEndTime = 0;

//heavy hitter counts and times
private static final ConcurrentHashMap<String,InstStats>_instStats = new ConcurrentHashMap<>();
private static final ConcurrentHashMap<String,InstStats> _instStats = new ConcurrentHashMap<>();
private static final ConcurrentHashMap<String, NGramBuilder<String, NGramStats>[]> _instStatsNGram = new ConcurrentHashMap<>();

// number of compiled/executed SP instructions
private static final LongAdder numExecutedSPInst = new LongAdder();
Expand Down Expand Up @@ -252,6 +297,8 @@ public static void reset()
DMLCompressionStatistics.reset();

FederatedStatistics.reset();

_instStatsNGram.clear();
}

public static void resetJITCompileTime(){
Expand Down Expand Up @@ -353,6 +400,177 @@ public static void maintainCPHeavyHitters( String instName, long timeNanos ) {
tmp.time.add(timeNanos);
tmp.count.increment();
}

/**
 * Appends one executed instruction to the n-gram builders of the calling
 * thread. Builders are kept per thread name and created on first use, one for
 * every size in DMLScript.STATISTICS_NGRAM_SIZES.
 *
 * @param instName name of the executed instruction
 * @param timeNanos measured duration of the instruction in nanoseconds
 */
public static void maintainNGrams(String instName, long timeNanos) {
    final String threadKey = Thread.currentThread().getName();

    NGramBuilder<String, NGramStats>[] threadBuilders = _instStatsNGram.computeIfAbsent(threadKey, unused -> {
        NGramBuilder<String, NGramStats>[] created = new NGramBuilder[DMLScript.STATISTICS_NGRAM_SIZES.length];
        int idx = 0;
        while (idx < created.length) {
            created[idx] = new NGramBuilder<String, NGramStats>(String.class, NGramStats.class, DMLScript.STATISTICS_NGRAM_SIZES[idx], s -> s, NGramStats::merge);
            idx++;
        }
        return created;
    });

    // Every builder receives its own single-sample stats object.
    for (NGramBuilder<String, NGramStats> ngramBuilder : threadBuilders) {
        ngramBuilder.append(instName, new NGramStats(1, timeNanos, 0));
    }
}

/**
 * Builds one fresh n-gram builder per configured size in
 * DMLScript.STATISTICS_NGRAM_SIZES and merges the matching builder of every
 * recorded thread into it.
 *
 * @return the merged builders, in the order of the configured sizes
 */
public static NGramBuilder<String, NGramStats>[] mergeNGrams() {
    NGramBuilder<String, NGramStats>[] merged = new NGramBuilder[DMLScript.STATISTICS_NGRAM_SIZES.length];

    for (int sizeIdx = 0; sizeIdx < merged.length; sizeIdx++) {
        NGramBuilder<String, NGramStats> target = new NGramBuilder<String, NGramStats>(String.class, NGramStats.class, DMLScript.STATISTICS_NGRAM_SIZES[sizeIdx], s -> s, NGramStats::merge);
        merged[sizeIdx] = target;

        // Slot sizeIdx of each per-thread array belongs to the same n-gram size.
        for (NGramBuilder<String, NGramStats>[] perThread : _instStatsNGram.values()) {
            target.merge(perThread[sizeIdx]);
        }
    }

    return merged;
}

/**
 * Formats, for each entry of {@code stats}, the standard deviation of its
 * time relative to its mean time, as a parenthesized, comma-separated list.
 * The array is read cyclically starting at {@code offset}.
 *
 * @param stats per-slot statistics
 * @param offset index of the slot that is printed first
 * @param prec number of decimal places
 * @param displayZero if false, values below 10^-prec are left blank (their
 *        separators are still emitted)
 * @return the formatted list, or "-" if no value was printed
 */
public static String getNGramStdDevs(NGramStats[] stats, int offset, int prec, boolean displayZero) {
    final String valueFormat = "%." + prec + "f";
    final double threshold = Math.pow(10, -prec);

    StringBuilder sb = new StringBuilder();
    sb.append("(");
    boolean containsData = false;
    for (int i = 0; i < stats.length; i++) {
        if (i != 0)
            sb.append(", ");
        NGramStats slot = stats[(offset + i) % stats.length];
        // stddev / mean == n * stddev / cumTime; the factor 1e9 converts the
        // standard deviation to nanoseconds. A slot whose summed time is 0 ns
        // would divide by zero and yield NaN, so it is reported as 0 instead.
        double relStdDev = (slot.cumTimeNanos == 0) ? 0d
            : 1000000000d * slot.n * Math.sqrt(slot.getTimeVariance()) / slot.cumTimeNanos;
        if (displayZero || relStdDev >= threshold) {
            sb.append(String.format(Locale.US, valueFormat, relStdDev));
            containsData = true;
        }
    }
    sb.append(")");
    return containsData ? sb.toString() : "-";
}

/**
 * Formats the mean time in seconds of each entry of {@code stats} as a
 * parenthesized, comma-separated list. The array is read cyclically starting
 * at {@code offset}.
 *
 * @param stats per-slot statistics
 * @param offset index of the slot that is printed first
 * @param prec number of decimal places
 * @return the formatted list, e.g. "(0.001, 0.020)"
 */
public static String getNGramAvgTimes(NGramStats[] stats, int offset, int prec) {
    final int slots = stats.length;
    StringBuilder out = new StringBuilder("(");
    for (int pos = 0; pos < slots; pos++) {
        if (pos > 0)
            out.append(", ");
        NGramStats slot = stats[(offset + pos) % slots];
        double meanSeconds = (slot.cumTimeNanos / 1000000000d) / slot.n;
        out.append(String.format(Locale.US, "%." + prec + "f", meanSeconds));
    }
    return out.append(")").toString();
}

/**
 * Renders the n-grams of {@code mbuilder} as CSV via NGramBuilder.toCSV.
 * The header consists of "N-Gram", "Time[s]", then for each of the
 * getSize() columns a name, a mean-time and a relative-std-dev column,
 * and finally "Count". Up to 100000 entries are requested from getTopK,
 * using the cumulative-time comparator.
 *
 * @param mbuilder builder whose n-grams are exported
 * @return the CSV text produced by NGramBuilder.toCSV
 */
public static String nGramToCSV(final NGramBuilder<String, NGramStats> mbuilder) {
    ArrayList<String> header = new ArrayList<>();
    header.add("N-Gram");
    header.add("Time[s]");

    // Three passes over the columns: plain name, mean time, relative std-dev.
    for (int pass = 0; pass < 3; pass++) {
        for (int col = 1; col <= mbuilder.getSize(); col++) {
            if (pass == 0)
                header.add("Col" + col);
            else if (pass == 1)
                header.add("Col" + col + "::Mean(Time[s])");
            else
                header.add("Col" + col + "::StdDev(Time[s])/Col" + col + "::Mean(Time[s])");
        }
    }

    header.add("Count");

    return NGramBuilder.toCSV(header.toArray(new String[header.size()]), mbuilder.getTopK(100000, Statistics.NGramStats.getComparator(), true), entry -> {
        // Strip the list decoration so each element lands in its own CSV cell.
        String identifierCells = entry.getIdentifier().replace("(", "").replace(")", "").replace(", ", ",");
        String meanCells = Statistics.getNGramAvgTimes(entry.getStats(), entry.getOffset(), 9).replace("-", "").replace("(", "").replace(")", "");
        String stdDevCells = Statistics.getNGramStdDevs(entry.getStats(), entry.getOffset(), 9, true).replace("-", "").replace("(", "").replace(")", "");

        StringBuilder row = new StringBuilder();
        row.append(identifierCells).append(",").append(meanCells).append(",");
        if (!stdDevCells.isEmpty()) {
            row.append(stdDevCells);
        } else {
            // No std-dev data: emit only the separators of the empty cells.
            int separators = mbuilder.getSize() - 1;
            while (separators-- > 0)
                row.append(",");
        }
        return row.toString();
    });
}

/**
 * Renders the top n-grams of {@code builder} as a fixed-width text table with
 * the columns "#", "N-Gram", "Time(s)", "StdDev(t)/Mean(t)" and "Count".
 * Entries are obtained from getTopK with the cumulative-time comparator;
 * n-gram identifiers longer than DMLScript.STATISTICS_MAX_WRAP_LEN are
 * wrapped onto continuation lines.
 *
 * @param builder merged n-gram builder to report on
 * @param num maximum number of n-grams to display
 * @return the table (one line per row, each terminated by a newline), or "-"
 *         if {@code num <= 0} or no n-grams have been recorded
 */
public static String getCommonNGrams(NGramBuilder<String, NGramStats> builder, int num) {
    if (num <= 0 || _instStatsNGram.isEmpty())
        return "-";

    @SuppressWarnings("unchecked")
    NGramBuilder.NGramEntry<String, NGramStats>[] topNGrams = builder.getTopK(num, NGramStats.getComparator(), true).toArray(NGramBuilder.NGramEntry[]::new);

    final String numCol = "#";
    final String instCol = "N-Gram";
    final String timeSCol = "Time(s)";
    final String timeSVar = "StdDev(t)/Mean(t)";
    final String countCol = "Count";
    DecimalFormat sFormat = new DecimalFormat("#,##0.000");

    int numHittersToDisplay = Math.min(num, topNGrams.length);
    int maxNumLen = String.valueOf(numHittersToDisplay).length();
    int maxInstLen = instCol.length();
    int maxTimeSLen = timeSCol.length();
    int maxTimeSVarLen = timeSVar.length();
    int maxCountLen = countCol.length();

    // Pass 1: format every cell once and derive the column widths from them.
    String[] timeSStrings = new String[numHittersToDisplay];
    String[] timeVarStrings = new String[numHittersToDisplay];
    for (int i = 0; i < numHittersToDisplay; i++) {
        timeSStrings[i] = sFormat.format(topNGrams[i].getCumStats().cumTimeNanos / 1000000000d);
        timeVarStrings[i] = getNGramStdDevs(topNGrams[i].getStats(), topNGrams[i].getOffset(), 3, false);

        maxInstLen = Math.max(maxInstLen, topNGrams[i].getIdentifier().length() + 1);
        maxTimeSLen = Math.max(maxTimeSLen, timeSStrings[i].length());
        maxTimeSVarLen = Math.max(maxTimeSVarLen, timeVarStrings[i].length());
        maxCountLen = Math.max(maxCountLen, String.valueOf(topNGrams[i].getOccurrences()).length());
    }
    maxInstLen = Math.min(maxInstLen, DMLScript.STATISTICS_MAX_WRAP_LEN);

    // The header and continuation lines are all-text; the first line of an
    // entry prints rank and count as numbers.
    final String textRowFormat = " %" + maxNumLen + "s %-" + maxInstLen + "s %"
        + maxTimeSLen + "s %" + maxTimeSVarLen + "s %" + maxCountLen + "s";
    final String dataRowFormat = " %" + maxNumLen + "d %-" + maxInstLen + "s %"
        + maxTimeSLen + "s %" + maxTimeSVarLen + "s %" + maxCountLen + "d";

    // Pass 2: emit the header and the rows.
    StringBuilder sb = new StringBuilder();
    sb.append(String.format(textRowFormat, numCol, instCol, timeSCol, timeSVar, countCol));
    sb.append("\n");
    for (int i = 0; i < numHittersToDisplay; i++) {
        String[] wrappedInstruction = wrap(topNGrams[i].getIdentifier(), maxInstLen);
        long count = topNGrams[i].getOccurrences();

        for (int wrapIter = 0; wrapIter < wrappedInstruction.length; wrapIter++) {
            if (wrapIter == 0) {
                sb.append(String.format(dataRowFormat,
                    (i + 1), wrappedInstruction[wrapIter], timeSStrings[i], timeVarStrings[i], count));
            }
            else {
                // Continuation line: only the wrapped identifier fragment.
                sb.append(String.format(textRowFormat,
                    "", wrappedInstruction[wrapIter], "", "", ""));
            }
            sb.append("\n");
        }
    }

    return sb.toString();
}

public static void maintainCPFuncCallStats(String instName) {
InstStats tmp = _instStats.get(instName);
Expand Down Expand Up @@ -679,6 +897,13 @@ public static String display(int maxHeavyHitters)
sb.append("Heavy hitter instructions:\n" + getHeavyHitters(maxHeavyHitters));
}

if (DMLScript.STATISTICS_NGRAMS) {
NGramBuilder<String, NGramStats>[] mergedNGrams = mergeNGrams();
for (int i = 0; i < DMLScript.STATISTICS_NGRAM_SIZES.length; i++) {
sb.append("Most common " + DMLScript.STATISTICS_NGRAM_SIZES[i] + "-grams (sorted by absolute time):\n" + getCommonNGrams(mergedNGrams[i], DMLScript.STATISTICS_TOP_K_NGRAMS));
}
}

if(DMLScript.FED_STATISTICS) {
sb.append("\n");
sb.append(FederatedStatistics.displayStatistics(DMLScript.FED_STATISTICS_COUNT));
Expand Down
Loading

0 comments on commit 9f96718

Please sign in to comment.