diff --git a/pmml-evaluator/src/main/java/org/jpmml/evaluator/NormalizationUtil.java b/pmml-evaluator/src/main/java/org/jpmml/evaluator/NormalizationUtil.java index f825ce32..3843b4b0 100644 --- a/pmml-evaluator/src/main/java/org/jpmml/evaluator/NormalizationUtil.java +++ b/pmml-evaluator/src/main/java/org/jpmml/evaluator/NormalizationUtil.java @@ -58,7 +58,7 @@ public Value normalize(NormContinuous normContinuous, Valu LinearNorm start; LinearNorm end; - int index = search(linearNorms, LinearNorm::requireOrig, value); + int index = binarySearch(linearNorms, LinearNorm::requireOrig, value); if(index < 0 || index == (linearNorms.size() - 1)){ OutlierTreatmentMethod outlierTreatmentMethod = normContinuous.getOutliers(); @@ -122,7 +122,7 @@ public Value denormalize(NormContinuous normContinuous, Va LinearNorm start; LinearNorm end; - int index = search(linearNorms, LinearNorm::requireNorm, value); + int index = binarySearch(linearNorms, LinearNorm::requireNorm, value); if(index < 0 || index == (linearNorms.size() - 1)){ throw new NotImplementedException(); } else @@ -136,17 +136,21 @@ public Value denormalize(NormContinuous normContinuous, Va } static - int search(List linearNorms, Function thresholdFunction, Value value){ + private int binarySearch(List linearNorms, Function thresholdFunction, Value value){ + int low = 0; + int high = linearNorms.size() - 1; - for(int i = 0, max = linearNorms.size(); i < max; i++){ - LinearNorm linearNorm = linearNorms.get(i); + while(low <= high){ + int mid = low + (high - low) / 2; + + LinearNorm linearNorm = linearNorms.get(mid); Number threshold = thresholdFunction.apply(linearNorm); if(value.compareTo(threshold) >= 0){ - if(i < (max - 1)){ - LinearNorm nextLinearNorm = linearNorms.get(i + 1); + if(mid < (linearNorms.size() - 1)){ + LinearNorm nextLinearNorm = linearNorms.get(mid + 1); Number nextThreshold = thresholdFunction.apply(nextLinearNorm); @@ -154,24 +158,24 @@ int search(List linearNorms, Function linearNorms = new ArrayList<>(); - - linearNorms.add(new LinearNorm(0d, null)); - linearNorms.add(new LinearNorm(1d, null)); - - assertEquals(-1, search(linearNorms, -1d)); - assertEquals(0, search(linearNorms, 0d)); - assertEquals(0, search(linearNorms, 1d)); - assertEquals(1, search(linearNorms, 2d)); - - linearNorms.add(new LinearNorm(2d, null)); - - assertEquals(-1, search(linearNorms, -1d)); - assertEquals(0, search(linearNorms, 1d)); - assertEquals(1, search(linearNorms, 2d)); - assertEquals(2, search(linearNorms, 3d)); - - linearNorms.add(new LinearNorm(3d, null)); - - assertEquals(-1, search(linearNorms,-1d)); - assertEquals(1, search(linearNorms, 2d)); - assertEquals(2, search(linearNorms, 3d)); - assertEquals(3, search(linearNorms, 4d)); - } - static private Double normalize(NormContinuous normContinuous, double value){ return (Double)NormalizationUtil.normalize(normContinuous, value); @@ -146,11 +116,6 @@ private Double denormalize(NormContinuous normContinuous, double value){ return (Double)NormalizationUtil.denormalize(normContinuous, value); } - static - private int search(List linearNorms, double value){ - return NormalizationUtil.search(linearNorms, LinearNorm::requireOrig, new DoubleValue(value)); - } - static private double interpolate(double x, double[] begin, double[] end){ return begin[1] + (x - begin[0]) / (end[0] - begin[0]) * (end[1] - begin[1]);