From b92a390a0b84f980d94d35d6638c0597b9797c2d Mon Sep 17 00:00:00 2001 From: Frederic Zoepffel Date: Wed, 28 Aug 2024 14:51:59 +0200 Subject: [PATCH] adjusted pruning to be done in getPairedCandidates --- scripts/builtin/incSliceLine.dml | 59 ++++++++++--------- .../part2/BuiltinIncSliceLineTest.java | 31 ++++++++-- .../functions/builtin/incSliceLineFull.dml | 4 +- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/scripts/builtin/incSliceLine.dml b/scripts/builtin/incSliceLine.dml index 18a88d7fb80..b11e0eeb43d 100644 --- a/scripts/builtin/incSliceLine.dml +++ b/scripts/builtin/incSliceLine.dml @@ -77,10 +77,11 @@ m_incSliceLine = function( Matrix[Double] indicesRemoved = matrix(0,0,0), Boolean verbose = FALSE, list[unknown] params = list(), Matrix[Double] prevFoffb = matrix(0,0,0), Matrix[Double] prevFoffe = matrix(0,0,0), + Boolean disableIncSizePruning = FALSE, Boolean disableIncScorePruning = FALSE, list[unknown] prevLattice = list(), list[unknown] metaPrevLattice = list(), list[unknown] prevStats = list(), Matrix[Double] prevTK = matrix(0,0,0), - Matrix[Double] prevTKC = matrix(0,0,0), Boolean encodeLat = TRUE, - Boolean disableIncSizePruning = FALSE, Boolean disableIncScorePruning = FALSE) + Matrix[Double] prevTKC = matrix(0,0,0), Boolean encodeLat = TRUE + ) return( Matrix[Double] TK, Matrix[Double] TKC, Matrix[Double] D, list[unknown] L, list[unknown] metaLattice, @@ -211,7 +212,6 @@ m_incSliceLine = function( # reduce dataset to relevant attributes (minSup, err>0), S reduced on-the-fly if( selFeat ){ X2 = removeEmpty(target=X2, margin="cols", select=t(selCols)); - changedX2 = removeEmpty(target=changedX2, margin="cols", select=t(selCols)); } # lattice enumeration w/ size/error pruning, one iteration per level @@ -221,9 +221,16 @@ m_incSliceLine = function( while( nrow(S) > 0 & sum(S) > 0 & level < n & level < maxL ) { level = level + 1; + # load one hot encoded previous lattice for the current level + prevLattice2 = matrix(0,0,0); + if(!disableIncSizePruning){ + prevLattice2 = preparePrevLattice(prevLattice, metaPrevLattice, prevFoffb, + prevFoffe, foffb, foffe, level, encodeLat, differentOffsets) + } + # enumerate candidate join pairs, incl size/error pruning nrS = nrow(S); - [S, minsc] = getPairedCandidates(S, minsc, R, TKC, level, eAvg, minSup, alpha, n2, foffb, foffe); + [S, minsc] = getPairedCandidates(S, minsc, R, TKC, level, eAvg, minSup, alpha, n2, foffb, foffe, prevLattice2, prevStats, changedX2, verbose, disableIncSizePruning) S2 = S; # prepare and store output lattice for next run @@ -235,29 +242,14 @@ m_incSliceLine = function( } L = append(L, Lrep); - # load one hot encoded previous lattice for the current level - prevLattice2 = matrix(0,0,0); - if(!disableIncSizePruning){ - prevLattice2 = preparePrevLattice(prevLattice, metaPrevLattice, prevFoffb, - prevFoffe, foffb, foffe, level, encodeLat, differentOffsets) - } - if(selFeat){ - if(length(prevLattice2)>0 & !disableIncSizePruning){ - prevLattice2 = removeEmpty(target=prevLattice2, margin="cols", select=t(selCols)); - } S2 = removeEmpty(target=S, margin="cols", select=t(selCols)); } if(verbose) { print("\nincSliceLine: level "+level+":") } - - # prune unchanged slices with slice size < minSup - if(level <= length(prevStats) & !disableIncSizePruning){ - [S, S2] = pruneUnchangedSlices(S, S2, prevLattice2, prevStats, changedX2, minSup, verbose, level); - } if(verbose) { print(" -- generated paired slice candidates: "+nrS+" -> "+nrow(S)); @@ -440,7 +432,7 @@ analyzeTopK = function(Matrix[Double] TKC) return(Double maxsc, Double minsc) { getPairedCandidates = function(Matrix[Double] S, Double minsc, Matrix[Double] R, Matrix[Double] TKC, Integer level, Double eAvg, Integer minSup, Double alpha, Integer n2, - Matrix[Double] foffb, Matrix[Double] foffe) + Matrix[Double] foffb, Matrix[Double] foffe, Matrix[Double] prevLattice2, list[unknown] prevStats, Matrix[Double] changedX2, Boolean verbose, Boolean disableIncSizePruning) return(Matrix[Double] P, Double minsc) { # prune invalid slices (possible without affecting overall @@ -471,6 +463,21 @@ getPairedCandidates = function(Matrix[Double] S, Double minsc, sm = min(P1 %*% R[,3], P2 %*% R[,3]) ss = min(P1 %*% R[,4], P2 %*% R[,4]) + # prune unchanged slices with slice size < minSup + if(level <= length(prevStats) & !disableIncSizePruning){ + I = pruneUnchangedSlices(P, prevLattice2, prevStats, changedX2, minSup, verbose, level); + if(sum(I) > 0){ + P = removeEmpty(target=P, margin="rows", select=I == 0); + P12 = removeEmpty(target=P12, margin="rows", select=I == 0); + ss = removeEmpty(target=ss, margin="rows", select=I == 0); + se = removeEmpty(target=se, margin="rows", select=I == 0); + sm = removeEmpty(target=sm, margin="rows", select=I == 0); + if(verbose) { + print(" -- Pruning " + sum(I) +" slices that are unchanged and below min sup."); + } + } + } + # prune invalid self joins (>1 bit per feature) I = matrix(1, nrow(P), 1); for( j in 1:ncol(foffb) ) { @@ -661,8 +668,8 @@ computeLowestPrevTK = function(Matrix[Double] prevTK2, Matrix[Double] X2,Matrix[ } } -pruneUnchangedSlices = function(Matrix[Double] S, Matrix[Double] S2, Matrix[Double] prevLattice2, list[unknown] prevStats, Matrix[Double] changedX2, Int minSup, Boolean verbose, Integer level) - return(Matrix[Double] S, Matrix[Double] S2) +pruneUnchangedSlices = function(Matrix[Double] S2, Matrix[Double] prevLattice2, list[unknown] prevStats, Matrix[Double] changedX2, Int minSup, Boolean verbose, Integer level) + return(Matrix[Double] unchangedAndBelowMinSupI) { unchangedS = matrix(0,0,ncol(prevLattice2)); unchangedR = matrix(0,0,4); @@ -687,13 +694,7 @@ pruneUnchangedSlices = function(Matrix[Double] S, Matrix[Double] S2, Matrix[Doub } } - if(sum(unchangedAndBelowMinSupI) > 0){ - S2 = removeEmpty(target=S2, margin="rows", select=unchangedAndBelowMinSupI == 0); - S = removeEmpty(target=S, margin="rows", select=unchangedAndBelowMinSupI == 0); - if(verbose) { - print(" -- Pruning " + sum(unchangedAndBelowMinSupI) +" slices that are unchanged and below min sup."); - } - } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinIncSliceLineTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinIncSliceLineTest.java index a115767fe34..81bb30a52fc 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinIncSliceLineTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinIncSliceLineTest.java @@ -22,6 +22,8 @@ import org.junit.Assert; import org.junit.Test; +import static org.junit.Assert.fail; + import java.util.HashMap; import org.apache.sysds.common.Types.ExecMode; @@ -217,6 +219,16 @@ public void testTop4HybridTPSelFullFewAdded() { runIncSliceLineTest(4, "e", false, true,2, 1, false, false, false, ExecMode.HYBRID); } + @Test + public void testTop4HybridTPSelFullFewAddedDisabledScore() { + runIncSliceLineTest(4, "e", false, true,2, 1, false, false, false, ExecMode.HYBRID, true, false); + } + + @Test + public void testTop4HybridTPSelFullFewAddedDisabledSize() { + runIncSliceLineTest(4, "e", false, true,2, 1, false, false, false, ExecMode.HYBRID, false, true); + } + @Test public void testTop4HybridDPSelFullFewAddedRemoved() { runIncSliceLineTest(4, "e", true, true,2, 1, false, true, false, ExecMode.HYBRID); @@ -982,7 +994,7 @@ public void testIncSliceLineCustomInputsFull() { }; - runIncSliceLineTest(newX, e, 10, "e", false, true, 50, 1, false, false, true, ExecMode.SINGLE_NODE); + runIncSliceLineTest(newX, e, 10, "e", false, true, 50, 1, false, false, true, ExecMode.SINGLE_NODE, false, false); } // @Test @@ -1050,12 +1062,18 @@ private void runIncSliceLineTest(int K, String err, boolean dp, boolean selCols, private void runIncSliceLineTest(int K, String err, boolean dp, boolean selCols, int proportionOfTuplesAddedInPercent, int proportionOfTuplesRemovedInPercent, boolean onlyNullEAdded, boolean removeTuples, boolean encodeLat, ExecMode mode) { - runIncSliceLineTest(null, null, K, err, dp, selCols, proportionOfTuplesAddedInPercent, proportionOfTuplesRemovedInPercent, onlyNullEAdded, removeTuples, encodeLat, mode); + runIncSliceLineTest(null, null, K, err, dp, selCols, proportionOfTuplesAddedInPercent, proportionOfTuplesRemovedInPercent, onlyNullEAdded, removeTuples, encodeLat, mode, false, false); + + } + + private void runIncSliceLineTest(int K, String err, boolean dp, boolean selCols, int proportionOfTuplesAddedInPercent, int proportionOfTuplesRemovedInPercent, boolean onlyNullEAdded, boolean removeTuples, boolean encodeLat, ExecMode mode, boolean disableScore, boolean disableSize) { + + runIncSliceLineTest(null, null, K, err, dp, selCols, proportionOfTuplesAddedInPercent, proportionOfTuplesRemovedInPercent, onlyNullEAdded, removeTuples, encodeLat, mode, disableScore, disableSize); } - private void runIncSliceLineTest(double[][] customX, double[][] customE,int K, String err, boolean dp, boolean selCols, int proportionOfTuplesAddedInPercent, int proportionOfTuplesRemovedInPercent, boolean onlyNullEAdded, boolean removeTuples, boolean encodeLat, ExecMode mode) { + private void runIncSliceLineTest(double[][] customX, double[][] customE,int K, String err, boolean dp, boolean selCols, int proportionOfTuplesAddedInPercent, int proportionOfTuplesRemovedInPercent, boolean onlyNullEAdded, boolean removeTuples, boolean encodeLat, ExecMode mode, boolean disableScore, boolean disableSize) { ExecMode platformOld = setExecMode(mode); loadTestConfiguration(getTestConfiguration(TEST_NAME2)); @@ -1135,7 +1153,7 @@ private void runIncSliceLineTest(double[][] customX, double[][] customE,int K, S fullDMLScriptName = HOME + TEST_NAME2 + ".dml"; programArgs = new String[] { "-args", input("addedX"), input("oldX"), input("oldE"), input("addedE"), String.valueOf(K), String.valueOf(!dp).toUpperCase(), String.valueOf(selCols).toUpperCase(), String.valueOf(encodeLat).toUpperCase(), input("indicesRemoved"), - String.valueOf(VERBOSE).toUpperCase(), output("R1"), output("R2") }; + String.valueOf(VERBOSE).toUpperCase(), output("R1"), output("R2"), String.valueOf(disableScore).toUpperCase(), String.valueOf(disableSize).toUpperCase() }; runTest(true, false, null, -1); @@ -1267,6 +1285,9 @@ public void testIncSliceLineCustomInputsFull(double[][] addedX, double[][] oldX, double[][] indicesRemoved = new double[1][1]; indicesRemoved[0][0] = 0; + + boolean disableScore = false; + boolean disableSize = false; writeInputMatrixWithMTD("addedX", addedX, false); @@ -1278,7 +1299,7 @@ public void testIncSliceLineCustomInputsFull(double[][] addedX, double[][] oldX, fullDMLScriptName = HOME + TEST_NAME2 + ".dml"; programArgs = new String[] { "-args", input("addedX"), input("oldX"), input("oldE"), input("addedE"), String.valueOf(K), String.valueOf(!dp).toUpperCase(), String.valueOf(selCols).toUpperCase(), String.valueOf(encodeLat).toUpperCase(), input("indicesRemoved"), - String.valueOf(VERBOSE).toUpperCase(), output("R1"), output("R2") }; + String.valueOf(VERBOSE).toUpperCase(), output("R1"), output("R2"), String.valueOf(disableScore).toUpperCase(), String.valueOf(disableSize).toUpperCase() }; runTest(true, false, null, -1); diff --git a/src/test/scripts/functions/builtin/incSliceLineFull.dml b/src/test/scripts/functions/builtin/incSliceLineFull.dml index d9dbafc7e0e..aeab7b8f681 100644 --- a/src/test/scripts/functions/builtin/incSliceLineFull.dml +++ b/src/test/scripts/functions/builtin/incSliceLineFull.dml @@ -26,6 +26,8 @@ oldE = read($3); addedE = read($4); totalE = rbind(oldE, addedE); indicesRemoved = read($9); +disableIncScorePruning = $13; +disableIncSizePruning = $14; if(nrow(indicesRemoved) > 0){ if(as.scalar(indicesRemoved[1]) == 0){ @@ -40,7 +42,7 @@ if(nrow(indicesRemoved) > 0){ # second increment [TK1, TKC1, D1, L1, meta1, Stats1, Xout1, eOut1, foffb2, foffe2, params] = incSliceLine(addedX=addedX, oldX = oldX, oldE = oldE, addedE=addedE, prevLattice = L, metaPrevLattice=meta, prevStats = Stats, prevTK = TK, prevTKC = TKC, k=$5, - alpha=0.95, minSup=4, tpEval=$6, selFeat=$7, encodeLat=$8, indicesRemoved=indicesRemoved, verbose=$10, params=params, prevFoffb = foffb, prevFoffe = foffe); + alpha=0.95, minSup=4, tpEval=$6, selFeat=$7, encodeLat=$8, indicesRemoved=indicesRemoved, verbose=$10, params=params, prevFoffb = foffb, prevFoffe = foffe, disableIncSizePruning = disableIncSizePruning, disableIncScorePruning = disableIncScorePruning); # prepare totalX and totalE for running sliceline on total data if(nrow(indicesRemoved) > 0){