From 95e62ebe51bbe548708b6dcdba957746f087d0d3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 27 Sep 2018 14:58:01 -0400 Subject: [PATCH] Swap Koloboke for Eclipse Collections; see #304 --- LICENSE | 18 ++++ app/oryx-app-common/pom.xml | 9 +- .../oryx/app/als/FeatureVectorsPartition.java | 17 ++-- .../app/als/PartitionedFeatureVectors.java | 13 ++- .../oryx/app/batch/mllib/rdf/RDFUpdate.java | 64 ++++++------- .../oryx/app/serving/als/Recommend.java | 5 +- .../app/serving/als/RecommendToAnonymous.java | 5 +- .../oryx/app/serving/als/RecommendToMany.java | 5 +- .../app/serving/als/RecommendWithContext.java | 5 +- .../oryx/app/serving/als/Similarity.java | 5 +- .../serving/als/model/ALSServingModel.java | 93 +++++++++---------- .../app/serving/als/model/TopNConsumer.java | 9 +- .../oryx/app/speed/als/ALSSpeedModel.java | 12 +-- .../oryx/common/collection/package-info.java | 3 +- .../lang/ToDoubleObjDoubleBiFunction.java | 36 +++++++ pom.xml | 27 ++---- 16 files changed, 170 insertions(+), 156 deletions(-) create mode 100644 framework/oryx-common/src/main/java/com/cloudera/oryx/common/lang/ToDoubleObjDoubleBiFunction.java diff --git a/LICENSE b/LICENSE index 3c6cbcc91..5ec149d24 100644 --- a/LICENSE +++ b/LICENSE @@ -182,3 +182,21 @@ Redistribution and use in source and binary forms, with or without modification, 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +======================================================================== +Eclipse Collections (https://github.com/eclipse/eclipse-collections) +======================================================================== + +Eclipse Distribution License - v 1.0 + +Copyright (c) 2007, Eclipse Foundation, Inc. and its licensors. + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + Neither the name of the Eclipse Foundation, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/app/oryx-app-common/pom.xml b/app/oryx-app-common/pom.xml index a9be91a3e..8fabef9ad 100644 --- a/app/oryx-app-common/pom.xml +++ b/app/oryx-app-common/pom.xml @@ -47,13 +47,12 @@ hadoop-client - com.koloboke - koloboke-api-jdk8 + org.eclipse.collections + eclipse-collections-api - com.koloboke - koloboke-impl-jdk8 - runtime + org.eclipse.collections + eclipse-collections diff --git a/app/oryx-app-common/src/main/java/com/cloudera/oryx/app/als/FeatureVectorsPartition.java b/app/oryx-app-common/src/main/java/com/cloudera/oryx/app/als/FeatureVectorsPartition.java index 03834358b..0ea4f8ab4 100644 --- a/app/oryx-app-common/src/main/java/com/cloudera/oryx/app/als/FeatureVectorsPartition.java +++ b/app/oryx-app-common/src/main/java/com/cloudera/oryx/app/als/FeatureVectorsPartition.java @@ -18,10 +18,9 @@ import java.util.Collection; import java.util.function.BiConsumer; -import com.koloboke.collect.map.ObjObjMap; -import com.koloboke.collect.map.hash.HashObjObjMaps; -import com.koloboke.collect.set.ObjSet; -import com.koloboke.collect.set.hash.HashObjSets; +import org.eclipse.collections.api.set.MutableSet; +import org.eclipse.collections.impl.map.mutable.UnifiedMap; +import org.eclipse.collections.impl.set.mutable.UnifiedSet; import com.cloudera.oryx.common.lang.AutoLock; import com.cloudera.oryx.common.lang.AutoReadWriteLock; @@ -35,13 +34,13 @@ */ public final class FeatureVectorsPartition implements FeatureVectors { - private final ObjObjMap vectors; - private final ObjSet recentIDs; + private final UnifiedMap vectors; + private final MutableSet recentIDs; private final AutoReadWriteLock lock; public FeatureVectorsPartition() { - vectors = HashObjObjMaps.newMutableMap(); - recentIDs = HashObjSets.newMutableSet(); + vectors = UnifiedMap.newMap(); + recentIDs = UnifiedSet.newSet(); lock = new AutoReadWriteLock(); } @@ -101,7 +100,7 @@ public void addAllRecentTo(Collection allRecent) { @Override public void retainRecentAndIDs(Collection newModelIDs) { try (AutoLock al = lock.autoWriteLock()) { - vectors.removeIf((key, value) -> !newModelIDs.contains(key) && !recentIDs.contains(key)); + vectors.keySet().removeIf(key -> !newModelIDs.contains(key) && !recentIDs.contains(key)); recentIDs.clear(); } } diff --git a/app/oryx-app-common/src/main/java/com/cloudera/oryx/app/als/PartitionedFeatureVectors.java b/app/oryx-app-common/src/main/java/com/cloudera/oryx/app/als/PartitionedFeatureVectors.java index 865da0c02..868d379aa 100644 --- a/app/oryx-app-common/src/main/java/com/cloudera/oryx/app/als/PartitionedFeatureVectors.java +++ b/app/oryx-app-common/src/main/java/com/cloudera/oryx/app/als/PartitionedFeatureVectors.java @@ -27,8 +27,7 @@ import java.util.stream.Stream; import com.google.common.base.Preconditions; -import com.koloboke.collect.map.ObjIntMap; -import com.koloboke.collect.map.hash.HashObjIntMaps; +import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap; import com.cloudera.oryx.common.lang.AutoLock; import com.cloudera.oryx.common.lang.AutoReadWriteLock; @@ -45,7 +44,7 @@ public final class PartitionedFeatureVectors implements FeatureVectors { private final FeatureVectorsPartition[] partitions; private final ToIntBiFunction partitioner; /** Maps item IDs to their existing partition, if any */ - private final ObjIntMap partitionMap; + private final ObjectIntHashMap partitionMap; /** Controls access to yPartitionMap. */ private final AutoReadWriteLock partitionMapLock; private final ExecutorService executor; @@ -66,7 +65,7 @@ public PartitionedFeatureVectors(int numPartitions, for (int i = 0; i < numPartitions; i++) { partitions[i] = new FeatureVectorsPartition(); } - partitionMap = HashObjIntMaps.newMutableMap(); + partitionMap = ObjectIntHashMap.newMap(); partitionMapLock = new AutoReadWriteLock(); this.partitioner = partitioner; this.executor = executor; @@ -152,7 +151,7 @@ public float[] getVector(String item) { // the partitioner might change its answer over time. int partition; try (AutoLock al = partitionMapLock.autoReadLock()) { - partition = partitionMap.getOrDefault(item, Integer.MIN_VALUE); + partition = partitionMap.getIfAbsent(item, Integer.MIN_VALUE); } if (partition < 0) { return null; @@ -165,7 +164,7 @@ public void setVector(String item, float[] vector) { int newPartition = partitioner.applyAsInt(item, vector); // Exclusive update to mapping -- careful since other locks are acquired inside here try (AutoLock al = partitionMapLock.autoWriteLock()) { - int existingPartition = partitionMap.getOrDefault(item, Integer.MIN_VALUE); + int existingPartition = partitionMap.getIfAbsent(item, Integer.MIN_VALUE); if (existingPartition >= 0 && existingPartition != newPartition) { // Move from one to the other partition, so first remove old entry partitions[existingPartition].removeVector(item); @@ -221,7 +220,7 @@ public double[] getVTV(boolean background) { @Override public String toString() { int maxSize = 16; - List partitionSizes = new ArrayList<>(maxSize); + Collection partitionSizes = new ArrayList<>(maxSize); for (int i = 0; i < partitions.length; i++) { int size = partitions[i].size(); if (size > 0) { diff --git a/app/oryx-app-mllib/src/main/java/com/cloudera/oryx/app/batch/mllib/rdf/RDFUpdate.java b/app/oryx-app-mllib/src/main/java/com/cloudera/oryx/app/batch/mllib/rdf/RDFUpdate.java index 414aadf8e..996b9148c 100644 --- a/app/oryx-app-mllib/src/main/java/com/cloudera/oryx/app/batch/mllib/rdf/RDFUpdate.java +++ b/app/oryx-app-mllib/src/main/java/com/cloudera/oryx/app/batch/mllib/rdf/RDFUpdate.java @@ -30,8 +30,6 @@ import java.util.stream.IntStream; import com.google.common.base.Preconditions; -import com.koloboke.collect.map.IntLongMap; -import com.koloboke.collect.map.hash.HashIntLongMaps; import com.typesafe.config.Config; import org.apache.hadoop.fs.Path; import org.apache.spark.api.java.JavaRDD; @@ -61,6 +59,8 @@ import org.dmg.pmml.mining.Segmentation; import org.dmg.pmml.tree.Node; import org.dmg.pmml.tree.TreeModel; +import org.eclipse.collections.api.map.primitive.IntLongMap; +import org.eclipse.collections.impl.map.mutable.primitive.IntLongHashMap; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import scala.collection.JavaConversions; @@ -162,8 +162,8 @@ public PMML buildModel(JavaSparkContext sparkContext, seed); } - List> treeNodeIDCounts = treeNodeExampleCounts(trainPointData, model); - Map predictorIndexCounts = predictorExampleCounts(trainPointData, model); + List treeNodeIDCounts = treeNodeExampleCounts(trainPointData, model); + IntLongHashMap predictorIndexCounts = predictorExampleCounts(trainPointData, model); return rdfModelToPMML(model, categoricalValueEncodings, @@ -264,31 +264,30 @@ private JavaRDD parseToLabeledPointRDD( * per tree in the model * @see #predictorExampleCounts(JavaRDD,RandomForestModel) */ - private static List> treeNodeExampleCounts(JavaRDD trainPointData, - RandomForestModel model) { + private static List treeNodeExampleCounts(JavaRDD trainPointData, + RandomForestModel model) { return trainPointData.mapPartitions(data -> { DecisionTreeModel[] trees = model.trees(); - List treeNodeIDCounts = IntStream.range(0, trees.length). - mapToObj(i -> HashIntLongMaps.newMutableMap()).collect(Collectors.toList()); + List treeNodeIDCounts = IntStream.range(0, trees.length). + mapToObj(i -> new IntLongHashMap()).collect(Collectors.toList()); data.forEachRemaining(datum -> { double[] featureVector = datum.features().toArray(); for (int i = 0; i < trees.length; i++) { DecisionTreeModel tree = trees[i]; - IntLongMap nodeIDCount = treeNodeIDCounts.get(i); + IntLongHashMap nodeIDCount = treeNodeIDCounts.get(i); org.apache.spark.mllib.tree.model.Node node = tree.topNode(); // This logic cloned from Node.predict: while (!node.isLeaf()) { // Count node ID - nodeIDCount.addValue(node.id(), 1); + nodeIDCount.addToValue(node.id(), 1); Split split = node.split().get(); int featureIndex = split.feature(); node = nextNode(featureVector, node, split, featureIndex); } - nodeIDCount.addValue(node.id(), 1); + nodeIDCount.addToValue(node.id(), 1); } }); - return Collections.>>singleton( - treeNodeIDCounts.stream().map(HashMap::new).collect(Collectors.toList())).iterator(); + return Collections.singleton(treeNodeIDCounts).iterator(); } ).reduce((a, b) -> { Preconditions.checkArgument(a.size() == b.size()); @@ -307,10 +306,10 @@ private static List> treeNodeExampleCounts(JavaRDD predictorExampleCounts(JavaRDD trainPointData, - RandomForestModel model) { + private static IntLongHashMap predictorExampleCounts(JavaRDD trainPointData, + RandomForestModel model) { return trainPointData.mapPartitions(data -> { - IntLongMap featureIndexCount = HashIntLongMaps.newMutableMap(); + IntLongHashMap featureIndexCount = new IntLongHashMap(); data.forEachRemaining(datum -> { double[] featureVector = datum.features().toArray(); for (DecisionTreeModel tree : model.trees()) { @@ -320,14 +319,12 @@ private static Map predictorExampleCounts(JavaRDD>singleton( - new HashMap<>(featureIndexCount)).iterator(); + return Collections.singleton(featureIndexCount).iterator(); }).reduce(RDFUpdate::merge); } @@ -352,26 +349,21 @@ private static org.apache.spark.mllib.tree.model.Node nextNode( } } - private static Map merge(Map a, Map b) { + private static IntLongHashMap merge(IntLongHashMap a, IntLongHashMap b) { if (b.size() > a.size()) { return merge(b, a); } - b.forEach((key, value) -> a.merge(key, value, (x, y) -> x + y)); + b.forEachKeyValue(a::addToValue); return a; } - private static long get(Map map, T key) { - Long count = map.get(key); - return count == null ? 0L : count; - } - private PMML rdfModelToPMML(RandomForestModel rfModel, CategoricalValueEncodings categoricalValueEncodings, int maxDepth, int maxSplitCandidates, String impurity, - List> nodeIDCounts, - Map predictorIndexCounts) { + List nodeIDCounts, + IntLongMap predictorIndexCounts) { boolean classificationTask = rfModel.algo().equals(Algo.Classification()); Preconditions.checkState(classificationTask == inputSchema.isClassification()); @@ -422,7 +414,7 @@ private PMML rdfModelToPMML(RandomForestModel rfModel, private TreeModel toTreeModel(DecisionTreeModel dtModel, CategoricalValueEncodings categoricalValueEncodings, - Map nodeIDCounts) { + IntLongMap nodeIDCounts) { boolean classificationTask = dtModel.algo().equals(Algo.Classification()); Preconditions.checkState(classificationTask == inputSchema.isClassification()); @@ -448,7 +440,7 @@ private TreeModel toTreeModel(DecisionTreeModel dtModel, modelNode.setPredicate(predicate); org.apache.spark.mllib.tree.model.Node treeNode = treeNodePredicate.getFirst(); - long nodeCount = get(nodeIDCounts, treeNode.id()); + long nodeCount = nodeIDCounts.get(treeNode.id()); modelNode.setRecordCount((double) nodeCount); if (treeNode.isLeaf()) { @@ -492,7 +484,7 @@ private TreeModel toTreeModel(DecisionTreeModel dtModel, org.apache.spark.mllib.tree.model.Node rightTreeNode = treeNode.rightNode().get(); org.apache.spark.mllib.tree.model.Node leftTreeNode = treeNode.leftNode().get(); - boolean defaultRight = get(nodeIDCounts, rightTreeNode.id()) > get(nodeIDCounts, leftTreeNode.id()); + boolean defaultRight = nodeIDCounts.get(rightTreeNode.id()) > nodeIDCounts.get(leftTreeNode.id()); modelNode.setDefaultChild(defaultRight ? positiveModelNode.getId() : negativeModelNode.getId()); // Right node is "positive", so carries the predicate. It must evaluate first @@ -528,7 +520,7 @@ private Predicate buildPredicate(Split split, // So the predicate will evaluate "not in" this set // More ugly casting @SuppressWarnings("unchecked") - List javaCategories = (List) (List) + Collection javaCategories = (Collection) (Collection) JavaConversions.seqAsJavaList(split.categories()); Set negativeEncodings = javaCategories.stream().map(Double::intValue).collect(Collectors.toSet()); @@ -548,11 +540,11 @@ private Predicate buildPredicate(Split split, } } - private double[] countsToImportances(Map predictorIndexCounts) { + private double[] countsToImportances(IntLongMap predictorIndexCounts) { double[] importances = new double[inputSchema.getNumPredictors()]; - long total = predictorIndexCounts.values().stream().mapToLong(l -> l).sum(); + long total = predictorIndexCounts.sum(); Preconditions.checkArgument(total > 0); - predictorIndexCounts.forEach((k, count) -> importances[k] = (double) count / total); + predictorIndexCounts.forEachKeyValue((k, count) -> importances[k] = (double) count / total); return importances; } diff --git a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/Recommend.java b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/Recommend.java index f71be80f3..f4f9b78c3 100644 --- a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/Recommend.java +++ b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/Recommend.java @@ -29,14 +29,13 @@ import javax.ws.rs.QueryParam; import javax.ws.rs.core.MediaType; -import com.koloboke.function.ObjDoubleToDoubleFunction; - import com.cloudera.oryx.api.serving.OryxServingException; import com.cloudera.oryx.app.als.Rescorer; import com.cloudera.oryx.app.als.RescorerProvider; import com.cloudera.oryx.app.serving.IDValue; import com.cloudera.oryx.app.serving.als.model.ALSServingModel; import com.cloudera.oryx.common.collection.Pair; +import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction; /** *

Responds to a GET request to @@ -92,7 +91,7 @@ public List get( } } - ObjDoubleToDoubleFunction rescoreFn = null; + ToDoubleObjDoubleBiFunction rescoreFn = null; RescorerProvider rescorerProvider = getALSServingModel().getRescorerProvider(); if (rescorerProvider != null) { Rescorer rescorer = rescorerProvider.getRecommendRescorer(Collections.singletonList(userID), diff --git a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendToAnonymous.java b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendToAnonymous.java index 1a60fe0c0..e521eab8b 100644 --- a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendToAnonymous.java +++ b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendToAnonymous.java @@ -31,14 +31,13 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.PathSegment; -import com.koloboke.function.ObjDoubleToDoubleFunction; - import com.cloudera.oryx.api.serving.OryxServingException; import com.cloudera.oryx.app.als.Rescorer; import com.cloudera.oryx.app.als.RescorerProvider; import com.cloudera.oryx.app.serving.IDValue; import com.cloudera.oryx.app.serving.als.model.ALSServingModel; import com.cloudera.oryx.common.collection.Pair; +import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction; /** *

Responds to a GET request to @@ -79,7 +78,7 @@ public List get( Collection knownItemsSet = new HashSet<>(knownItems); Predicate allowedFn = v -> !knownItemsSet.contains(v); - ObjDoubleToDoubleFunction rescoreFn = null; + ToDoubleObjDoubleBiFunction rescoreFn = null; RescorerProvider rescorerProvider = getALSServingModel().getRescorerProvider(); if (rescorerProvider != null) { Rescorer rescorer = rescorerProvider.getRecommendToAnonymousRescorer(knownItems, diff --git a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendToMany.java b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendToMany.java index d2f492e49..e6c87b344 100644 --- a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendToMany.java +++ b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendToMany.java @@ -31,14 +31,13 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.PathSegment; -import com.koloboke.function.ObjDoubleToDoubleFunction; - import com.cloudera.oryx.api.serving.OryxServingException; import com.cloudera.oryx.app.als.Rescorer; import com.cloudera.oryx.app.als.RescorerProvider; import com.cloudera.oryx.app.serving.IDValue; import com.cloudera.oryx.app.serving.als.model.ALSServingModel; import com.cloudera.oryx.common.collection.Pair; +import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction; /** *

Responds to a GET request to @@ -90,7 +89,7 @@ public List get( allowedFn = v -> !userKnownItems.contains(v); } - ObjDoubleToDoubleFunction rescoreFn = null; + ToDoubleObjDoubleBiFunction rescoreFn = null; RescorerProvider rescorerProvider = getALSServingModel().getRescorerProvider(); if (rescorerProvider != null) { Rescorer rescorer = rescorerProvider.getRecommendRescorer(userIDs, rescorerParams); diff --git a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendWithContext.java b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendWithContext.java index f2aabec39..2ae0f3bd6 100644 --- a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendWithContext.java +++ b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/RecommendWithContext.java @@ -31,14 +31,13 @@ import java.util.stream.Collectors; import java.util.stream.Stream; -import com.koloboke.function.ObjDoubleToDoubleFunction; - import com.cloudera.oryx.api.serving.OryxServingException; import com.cloudera.oryx.app.als.Rescorer; import com.cloudera.oryx.app.als.RescorerProvider; import com.cloudera.oryx.app.serving.IDValue; import com.cloudera.oryx.app.serving.als.model.ALSServingModel; import com.cloudera.oryx.common.collection.Pair; +import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction; /** *

Responds to a GET request to @@ -84,7 +83,7 @@ public List get( } Predicate allowedFn = v -> !knownItems.contains(v); - ObjDoubleToDoubleFunction rescoreFn = null; + ToDoubleObjDoubleBiFunction rescoreFn = null; RescorerProvider rescorerProvider = getALSServingModel().getRescorerProvider(); if (rescorerProvider != null) { Rescorer rescorer = rescorerProvider.getRecommendRescorer(Collections.singletonList(userID), diff --git a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/Similarity.java b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/Similarity.java index 9fcf7ddb7..361fe33a1 100644 --- a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/Similarity.java +++ b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/Similarity.java @@ -30,14 +30,13 @@ import javax.ws.rs.core.MediaType; import javax.ws.rs.core.PathSegment; -import com.koloboke.function.ObjDoubleToDoubleFunction; - import com.cloudera.oryx.api.serving.OryxServingException; import com.cloudera.oryx.app.als.Rescorer; import com.cloudera.oryx.app.als.RescorerProvider; import com.cloudera.oryx.app.serving.IDValue; import com.cloudera.oryx.app.serving.als.model.ALSServingModel; import com.cloudera.oryx.common.collection.Pair; +import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction; /** *

Responds to a GET request to @@ -84,7 +83,7 @@ public List get( } Predicate allowedFn = v -> !knownItems.contains(v); - ObjDoubleToDoubleFunction rescoreFn = null; + ToDoubleObjDoubleBiFunction rescoreFn = null; RescorerProvider rescorerProvider = getALSServingModel().getRescorerProvider(); if (rescorerProvider != null) { Rescorer rescorer = rescorerProvider.getMostSimilarItemsRescorer(rescorerParams); diff --git a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/model/ALSServingModel.java b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/model/ALSServingModel.java index e5ce114c5..be1bb2e97 100644 --- a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/model/ALSServingModel.java +++ b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/model/ALSServingModel.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -28,14 +29,11 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import com.koloboke.collect.ObjCursor; -import com.koloboke.collect.map.ObjIntMap; -import com.koloboke.collect.map.ObjObjMap; -import com.koloboke.collect.map.hash.HashObjIntMaps; -import com.koloboke.collect.map.hash.HashObjObjMaps; -import com.koloboke.collect.set.ObjSet; -import com.koloboke.collect.set.hash.HashObjSets; -import com.koloboke.function.ObjDoubleToDoubleFunction; +import org.eclipse.collections.api.map.MutableMap; +import org.eclipse.collections.api.set.MutableSet; +import org.eclipse.collections.impl.map.mutable.UnifiedMap; +import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap; +import org.eclipse.collections.impl.set.mutable.UnifiedSet; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -49,6 +47,7 @@ import com.cloudera.oryx.common.collection.Pairs; import com.cloudera.oryx.common.lang.AutoLock; import com.cloudera.oryx.common.lang.AutoReadWriteLock; +import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction; import com.cloudera.oryx.common.math.Solver; /** @@ -68,11 +67,11 @@ public final class ALSServingModel implements ServingModel { /** Item-feature matrix. This is partitioned into several maps for parallel access. */ private final PartitionedFeatureVectors Y; /** Remembers items that each user has interacted with*/ - private final ObjObjMap> knownItems; // Right now no corresponding "knownUsers" object + private final MutableMap> knownItems; // Right now no corresponding "knownUsers" object private final AutoReadWriteLock knownItemsLock; - private final ObjSet expectedUserIDs; + private final MutableSet expectedUserIDs; private final AutoReadWriteLock expectedUserIDsLock; - private final ObjSet expectedItemIDs; + private final MutableSet expectedItemIDs; private final AutoReadWriteLock expectedItemIDsLock; private final SolverCache cachedYTYSolver; /** Number of features used in the model. */ @@ -102,12 +101,12 @@ public final class ALSServingModel implements ServingModel { executor, (String id, float[] vector) -> lsh.getIndexFor(vector)); - knownItems = HashObjObjMaps.newMutableMap(); + knownItems = UnifiedMap.newMap(); knownItemsLock = new AutoReadWriteLock(); - expectedUserIDs = HashObjSets.newMutableSet(); + expectedUserIDs = UnifiedSet.newSet(); expectedUserIDsLock = new AutoReadWriteLock(); - expectedItemIDs = HashObjSets.newMutableSet(); + expectedItemIDs = UnifiedSet.newSet(); expectedItemIDsLock = new AutoReadWriteLock(); cachedYTYSolver = new SolverCache(executor, Y); @@ -161,20 +160,20 @@ void setItemVector(String item, float[] vector) { * @return set of known items for the user (immutable, but thread-safe) */ public Set getKnownItems(String user) { - ObjSet knownItems = doGetKnownItems(user); - if (knownItems == null) { + MutableSet knownItemsForUser = doGetKnownItems(user); + if (knownItemsForUser == null) { return Collections.emptySet(); } - synchronized (knownItems) { - if (knownItems.isEmpty()) { + synchronized (knownItemsForUser) { + if (knownItemsForUser.isEmpty()) { return Collections.emptySet(); } // Must copy since the original object is synchronized - return HashObjSets.newImmutableSet(knownItems); + return knownItemsForUser.clone().asUnmodifiable(); } } - private ObjSet doGetKnownItems(String user) { + private MutableSet doGetKnownItems(String user) { try (AutoLock al = knownItemsLock.autoReadLock()) { return knownItems.get(user); } @@ -184,14 +183,15 @@ private ObjSet doGetKnownItems(String user) { * @return mapping of user IDs to count of items the user has interacted with */ public Map getUserCounts() { - ObjIntMap counts = HashObjIntMaps.newUpdatableMap(); + Map counts; try (AutoLock al = knownItemsLock.autoReadLock()) { + counts = new HashMap<>(knownItems.size()); knownItems.forEach((userID, ids) -> { int numItems; synchronized (ids) { numItems = ids.size(); } - counts.addValue(userID, numItems); + counts.put(userID, numItems); }); } return counts; @@ -201,25 +201,28 @@ public Map getUserCounts() { * @return mapping of item IDs to count of users that have interacted with that item */ public Map getItemCounts() { - ObjIntMap counts = HashObjIntMaps.newUpdatableMap(); + ObjectIntHashMap counts = ObjectIntHashMap.newMap(); try (AutoLock al = knownItemsLock.autoReadLock()) { knownItems.values().forEach(ids -> { synchronized (ids) { - ids.forEach(id -> counts.addValue(id, 1)); + ids.forEach(id -> counts.addToValue(id, 1)); } }); } - return counts; + // No way to get Java map from primitive map directly (?) + Map javaCounts = new HashMap<>(counts.size()); + counts.forEachKeyValue(javaCounts::put); + return javaCounts; } void addKnownItems(String user, Collection items) { if (!items.isEmpty()) { - ObjSet knownItemsForUser = doGetKnownItems(user); + MutableSet knownItemsForUser = doGetKnownItems(user); if (knownItemsForUser == null) { try (AutoLock al = knownItemsLock.autoWriteLock()) { // Check again - knownItemsForUser = knownItems.computeIfAbsent(user, k -> HashObjSets.newMutableSet()); + knownItemsForUser = knownItems.computeIfAbsent(user, k -> UnifiedSet.newSet()); } } @@ -238,17 +241,17 @@ public List> getKnownItemVectorsForUser(String user) { if (userVector == null) { return null; } - Collection knownItems = doGetKnownItems(user); - if (knownItems == null) { + Collection knownItemsForUser = doGetKnownItems(user); + if (knownItemsForUser == null) { return null; } - synchronized (knownItems) { - int size = knownItems.size(); + synchronized (knownItemsForUser) { + int size = knownItemsForUser.size(); if (size == 0) { return null; } List> idVectors = new ArrayList<>(size); - for (String itemID : knownItems) { + for (String itemID : knownItemsForUser) { float[] vector = getItemVector(itemID); if (vector != null) { idVectors.add(new Pair<>(itemID, vector)); @@ -260,7 +263,7 @@ public List> getKnownItemVectorsForUser(String user) { public Stream> topN( CosineDistanceSensitiveFunction scoreFn, - ObjDoubleToDoubleFunction rescoreFn, + ToDoubleObjDoubleBiFunction rescoreFn, int howMany, Predicate allowedPredicate) { int[] candidateIndices = lsh.getCandidateIndices(scoreFn.getTargetVector()); @@ -279,7 +282,7 @@ public Stream> topN( * @return all user IDs in the model */ public Collection getAllUserIDs() { - Collection allUserIDs = HashObjSets.newMutableSet(); + Collection allUserIDs = UnifiedSet.newSet(X.size()); X.addAllIDsTo(allUserIDs); return allUserIDs; } @@ -288,7 +291,7 @@ public Collection getAllUserIDs() { * @return all item IDs in the model */ public Collection getAllItemIDs() { - Collection allItemIDs = HashObjSets.newMutableSet(); + Collection allItemIDs = UnifiedSet.newSet(Y.size()); Y.addAllIDsTo(allItemIDs); return allItemIDs; } @@ -345,34 +348,22 @@ void retainRecentAndItemIDs(Collection items) { */ void retainRecentAndKnownItems(Collection users, Collection items) { // Keep all users in the new model, or, that have been added since last model - Collection recentUserIDs = HashObjSets.newMutableSet(); + MutableSet recentUserIDs = UnifiedSet.newSet(); X.addAllRecentTo(recentUserIDs); try (AutoLock al = knownItemsLock.autoWriteLock()) { - knownItems.removeIf((key, value) -> !users.contains(key) && !recentUserIDs.contains(key)); + knownItems.keySet().removeIf(key -> !users.contains(key) && !recentUserIDs.contains(key)); } // This will be easier to quickly copy the whole (smallish) set rather than // deal with locks below - Collection allRecentKnownItems = HashObjSets.newMutableSet(); + MutableSet allRecentKnownItems = UnifiedSet.newSet(); Y.addAllRecentTo(allRecentKnownItems); Predicate notKeptOrRecent = value -> !items.contains(value) && !allRecentKnownItems.contains(value); try (AutoLock al = knownItemsLock.autoReadLock()) { knownItems.values().forEach(knownItemsForUser -> { synchronized (knownItemsForUser) { - // knownItemsForUser.removeIf(notKeptOrRecent); - // TODO remove this temporary hack workaround and restore above - // see https://github.com/OryxProject/oryx/issues/304 - ObjCursor cursor = knownItemsForUser.cursor(); - while (cursor.moveNext()) { - Object o = cursor.elem(); - if (!(o instanceof String)) { - log.warn("Found non-String collection: {}", o); - cursor.remove(); - } else if (notKeptOrRecent.test((String) o)) { - cursor.remove(); - } - } + knownItemsForUser.removeIf(notKeptOrRecent); } }); } diff --git a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/model/TopNConsumer.java b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/model/TopNConsumer.java index ab9a44c1e..ee9f26dde 100644 --- a/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/model/TopNConsumer.java +++ b/app/oryx-app-serving/src/main/java/com/cloudera/oryx/app/serving/als/model/TopNConsumer.java @@ -22,27 +22,26 @@ import java.util.function.ToDoubleFunction; import java.util.stream.Stream; -import com.koloboke.function.ObjDoubleToDoubleFunction; - import com.cloudera.oryx.common.collection.Pair; import com.cloudera.oryx.common.collection.Pairs; +import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction; final class TopNConsumer implements BiConsumer { private static final Predicate ALWAYS_ALLOWED = key -> true; - private static final ObjDoubleToDoubleFunction NO_RESCORE = (key, value) -> value; + private static final ToDoubleObjDoubleBiFunction NO_RESCORE = (key, value) -> value; private final Queue> topN; private final int howMany; private final ToDoubleFunction scoreFn; - private final ObjDoubleToDoubleFunction rescoreFn; + private final ToDoubleObjDoubleBiFunction rescoreFn; private final Predicate allowedPredicate; /** Local copy of lower bound of min score in the priority queue, to avoid polling. */ private double topScoreLowerBound; TopNConsumer(int howMany, ToDoubleFunction scoreFn, - ObjDoubleToDoubleFunction rescoreFn, + ToDoubleObjDoubleBiFunction rescoreFn, Predicate allowedPredicate) { this.topN = new PriorityQueue<>(howMany, Pairs.orderBySecond(Pairs.SortOrder.ASCENDING)); this.howMany = howMany; diff --git a/app/oryx-app/src/main/java/com/cloudera/oryx/app/speed/als/ALSSpeedModel.java b/app/oryx-app/src/main/java/com/cloudera/oryx/app/speed/als/ALSSpeedModel.java index 558fed17f..4a677510b 100644 --- a/app/oryx-app/src/main/java/com/cloudera/oryx/app/speed/als/ALSSpeedModel.java +++ b/app/oryx-app/src/main/java/com/cloudera/oryx/app/speed/als/ALSSpeedModel.java @@ -21,8 +21,8 @@ import com.google.common.base.Preconditions; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import com.koloboke.collect.set.ObjSet; -import com.koloboke.collect.set.hash.HashObjSets; +import org.eclipse.collections.api.set.MutableSet; +import org.eclipse.collections.impl.set.mutable.UnifiedSet; import com.cloudera.oryx.api.speed.SpeedModel; import com.cloudera.oryx.app.als.FeatureVectors; @@ -45,9 +45,9 @@ public final class ALSSpeedModel implements SpeedModel { private final FeatureVectors X; /** Item-feature matrix. */ private final FeatureVectors Y; - private final ObjSet expectedUserIDs; + private final MutableSet expectedUserIDs; private final AutoReadWriteLock expectedUserIDsLock; - private final ObjSet expectedItemIDs; + private final MutableSet expectedItemIDs; private final AutoReadWriteLock expectedItemIDsLock; /** Number of features used in the model. */ private final int features; @@ -71,9 +71,9 @@ public final class ALSSpeedModel implements SpeedModel { int numPartitions = Runtime.getRuntime().availableProcessors(); X = new PartitionedFeatureVectors(numPartitions, executor); Y = new PartitionedFeatureVectors(numPartitions, executor); - expectedUserIDs = HashObjSets.newMutableSet(); + expectedUserIDs = UnifiedSet.newSet(); expectedUserIDsLock = new AutoReadWriteLock(); - expectedItemIDs = HashObjSets.newMutableSet(); + expectedItemIDs = UnifiedSet.newSet(); expectedItemIDsLock = new AutoReadWriteLock(); this.features = features; this.implicit = implicit; diff --git a/framework/oryx-common/src/main/java/com/cloudera/oryx/common/collection/package-info.java b/framework/oryx-common/src/main/java/com/cloudera/oryx/common/collection/package-info.java index edd88878d..c3ffc2d4b 100644 --- a/framework/oryx-common/src/main/java/com/cloudera/oryx/common/collection/package-info.java +++ b/framework/oryx-common/src/main/java/com/cloudera/oryx/common/collection/package-info.java @@ -14,7 +14,6 @@ */ /** - * Collection-related utility methods and support classes, related to the Java Collections API - * and other collections libraries like Koloboke. + * Collection-related utility methods and support classes. */ package com.cloudera.oryx.common.collection; \ No newline at end of file diff --git a/framework/oryx-common/src/main/java/com/cloudera/oryx/common/lang/ToDoubleObjDoubleBiFunction.java b/framework/oryx-common/src/main/java/com/cloudera/oryx/common/lang/ToDoubleObjDoubleBiFunction.java new file mode 100644 index 000000000..4bf018d46 --- /dev/null +++ b/framework/oryx-common/src/main/java/com/cloudera/oryx/common/lang/ToDoubleObjDoubleBiFunction.java @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2018, Cloudera, Inc. All Rights Reserved. + * + * Cloudera, Inc. licenses this file to you under the Apache License, + * Version 2.0 (the "License"). You may not use this file except in + * compliance with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * This software is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR + * CONDITIONS OF ANY KIND, either express or implied. See the License for + * the specific language governing permissions and limitations under the + * License. + */ + +package com.cloudera.oryx.common.lang; + +/** + * Analogous to both {@link java.util.function.ToDoubleBiFunction} and + * {@link java.util.function.ObjDoubleConsumer}, combined. + * + * @param the type of the first argument to the function + */ +@FunctionalInterface +public interface ToDoubleObjDoubleBiFunction { + + /** + * Applies this function to the given arguments. + * + * @param t the first function argument + * @param u the second function argument + * @return the function result + */ + double applyAsDouble(T t, double u); + +} diff --git a/pom.xml b/pom.xml index 3967463cc..23ca5c8f5 100644 --- a/pom.xml +++ b/pom.xml @@ -321,15 +321,14 @@ - com.koloboke - koloboke-api-jdk8 - ${koloboke.version} + org.eclipse.collections + eclipse-collections-api + ${ec.version} - com.koloboke - koloboke-impl-jdk8 - ${koloboke.version} - runtime + org.eclipse.collections + eclipse-collections + ${ec.version} @@ -568,12 +567,12 @@ 2.11 ${scala.minor.version}.12 + 9.2.0 2.7.5 2.8.7 3.0.1 2.27 0.11.0.1 - 1.0.0 1.11.1 1.1.2 1.7.25 @@ -639,7 +638,6 @@ - 3.2.5 @@ -855,17 +853,6 @@ META-INF/*.RSA - - - com.koloboke:* - - com/koloboke/collect/**/*Byte*.class - com/koloboke/collect/**/*Double*.class - com/koloboke/collect/**/*Float*.class - com/koloboke/collect/**/*Short*.class - - - *:*