
Swap Koloboke for Eclipse Collections; see OryxProject#304
srowen committed Sep 27, 2018
1 parent 3212628 commit f85cf3c
Showing 16 changed files with 170 additions and 156 deletions.
18 changes: 18 additions & 0 deletions LICENSE
@@ -202,3 +202,21 @@ Redistribution and use in source and binary forms, with or without modification,
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

========================================================================
Eclipse Collections (https://github.com/eclipse/eclipse-collections)
========================================================================

Eclipse Distribution License - v 1.0

Copyright (c) 2007, Eclipse Foundation, Inc. and its licensors.

All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
Neither the name of the Eclipse Foundation, Inc. nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
9 changes: 4 additions & 5 deletions app/oryx-app-common/pom.xml
@@ -47,13 +47,12 @@
<artifactId>hadoop-client</artifactId>
</dependency>
<dependency>
<groupId>com.koloboke</groupId>
<artifactId>koloboke-api-jdk8</artifactId>
<groupId>org.eclipse.collections</groupId>
<artifactId>eclipse-collections-api</artifactId>
</dependency>
<dependency>
<groupId>com.koloboke</groupId>
<artifactId>koloboke-impl-jdk8</artifactId>
<scope>runtime</scope>
<groupId>org.eclipse.collections</groupId>
<artifactId>eclipse-collections</artifactId>
</dependency>

<dependency>
@@ -18,10 +18,9 @@
import java.util.Collection;
import java.util.function.BiConsumer;

import com.koloboke.collect.map.ObjObjMap;
import com.koloboke.collect.map.hash.HashObjObjMaps;
import com.koloboke.collect.set.ObjSet;
import com.koloboke.collect.set.hash.HashObjSets;
import org.eclipse.collections.api.set.MutableSet;
import org.eclipse.collections.impl.map.mutable.UnifiedMap;
import org.eclipse.collections.impl.set.mutable.UnifiedSet;

import com.cloudera.oryx.common.lang.AutoLock;
import com.cloudera.oryx.common.lang.AutoReadWriteLock;
@@ -35,13 +34,13 @@
*/
public final class FeatureVectorsPartition implements FeatureVectors {

private final ObjObjMap<String,float[]> vectors;
private final ObjSet<String> recentIDs;
private final UnifiedMap<String,float[]> vectors;
private final MutableSet<String> recentIDs;
private final AutoReadWriteLock lock;

public FeatureVectorsPartition() {
vectors = HashObjObjMaps.newMutableMap();
recentIDs = HashObjSets.newMutableSet();
vectors = UnifiedMap.newMap();
recentIDs = UnifiedSet.newSet();
lock = new AutoReadWriteLock();
}

@@ -101,7 +100,7 @@ public void addAllRecentTo(Collection<String> allRecent) {
@Override
public void retainRecentAndIDs(Collection<String> newModelIDs) {
try (AutoLock al = lock.autoWriteLock()) {
vectors.removeIf((key, value) -> !newModelIDs.contains(key) && !recentIDs.contains(key));
vectors.keySet().removeIf(key -> !newModelIDs.contains(key) && !recentIDs.contains(key));
recentIDs.clear();
}
}
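A minimal sketch (not from this commit) of the Eclipse Collections types now used by FeatureVectorsPartition. UnifiedMap and UnifiedSet implement the standard java.util.Map and java.util.Set interfaces, which is why the retain logic above can move from Koloboke's two-argument removeIf to a plain keySet().removeIf. The class name and sample IDs below are hypothetical.

import java.util.Set;

import org.eclipse.collections.api.set.MutableSet;
import org.eclipse.collections.impl.map.mutable.UnifiedMap;
import org.eclipse.collections.impl.set.mutable.UnifiedSet;

public class FeatureVectorsSketch {
  public static void main(String[] args) {
    // UnifiedMap/UnifiedSet are drop-in java.util implementations
    UnifiedMap<String,float[]> vectors = UnifiedMap.newMap();
    MutableSet<String> recentIDs = UnifiedSet.newSet();

    vectors.put("user1", new float[] {0.1f, 0.2f});
    vectors.put("user2", new float[] {0.3f, 0.4f});
    recentIDs.add("user2");

    Set<String> newModelIDs = UnifiedSet.newSetWith("user3");
    // Same shape as retainRecentAndIDs: drop vectors that are neither in the new model nor recent
    vectors.keySet().removeIf(key -> !newModelIDs.contains(key) && !recentIDs.contains(key));

    System.out.println(vectors.keySet()); // [user2]
    recentIDs.clear();
  }
}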
@@ -27,8 +27,7 @@
import java.util.stream.Stream;

import com.google.common.base.Preconditions;
import com.koloboke.collect.map.ObjIntMap;
import com.koloboke.collect.map.hash.HashObjIntMaps;
import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap;

import com.cloudera.oryx.common.lang.AutoLock;
import com.cloudera.oryx.common.lang.AutoReadWriteLock;
Expand All @@ -45,7 +44,7 @@ public final class PartitionedFeatureVectors implements FeatureVectors {
private final FeatureVectorsPartition[] partitions;
private final ToIntBiFunction<String,float[]> partitioner;
/** Maps item IDs to their existing partition, if any */
private final ObjIntMap<String> partitionMap;
private final ObjectIntHashMap<String> partitionMap;
/** Controls access to yPartitionMap. */
private final AutoReadWriteLock partitionMapLock;
private final ExecutorService executor;
@@ -66,7 +65,7 @@ public PartitionedFeatureVectors(int numPartitions,
for (int i = 0; i < numPartitions; i++) {
partitions[i] = new FeatureVectorsPartition();
}
partitionMap = HashObjIntMaps.newMutableMap();
partitionMap = ObjectIntHashMap.newMap();
partitionMapLock = new AutoReadWriteLock();
this.partitioner = partitioner;
this.executor = executor;
@@ -152,7 +151,7 @@ public float[] getVector(String item) {
// the partitioner might change its answer over time.
int partition;
try (AutoLock al = partitionMapLock.autoReadLock()) {
partition = partitionMap.getOrDefault(item, Integer.MIN_VALUE);
partition = partitionMap.getIfAbsent(item, Integer.MIN_VALUE);
}
if (partition < 0) {
return null;
@@ -165,7 +164,7 @@ public void setVector(String item, float[] vector) {
int newPartition = partitioner.applyAsInt(item, vector);
// Exclusive update to mapping -- careful since other locks are acquired inside here
try (AutoLock al = partitionMapLock.autoWriteLock()) {
int existingPartition = partitionMap.getOrDefault(item, Integer.MIN_VALUE);
int existingPartition = partitionMap.getIfAbsent(item, Integer.MIN_VALUE);
if (existingPartition >= 0 && existingPartition != newPartition) {
// Move from one to the other partition, so first remove old entry
partitions[existingPartition].removeVector(item);
@@ -221,7 +220,7 @@ public double[] getVTV(boolean background) {
@Override
public String toString() {
int maxSize = 16;
List<String> partitionSizes = new ArrayList<>(maxSize);
Collection<String> partitionSizes = new ArrayList<>(maxSize);
for (int i = 0; i < partitions.length; i++) {
int size = partitions[i].size();
if (size > 0) {
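For context (again a sketch, not repository code), the primitive ObjectIntHashMap replaces Koloboke's ObjIntMap in PartitionedFeatureVectors, and getIfAbsent(key, default) plays the role of getOrDefault: a sentinel below zero signals that an item has no partition yet, without boxing to Integer.

import org.eclipse.collections.impl.map.mutable.primitive.ObjectIntHashMap;

public class PartitionMapSketch {
  public static void main(String[] args) {
    ObjectIntHashMap<String> partitionMap = ObjectIntHashMap.newMap();
    partitionMap.put("itemA", 3); // itemA lives in partition 3

    int known = partitionMap.getIfAbsent("itemA", Integer.MIN_VALUE);   // 3
    int unknown = partitionMap.getIfAbsent("itemB", Integer.MIN_VALUE); // Integer.MIN_VALUE

    System.out.println(known >= 0);   // true  -> partition already assigned
    System.out.println(unknown >= 0); // false -> no partition yet, as tested in getVector/setVector
  }
}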
@@ -30,8 +30,6 @@
import java.util.stream.IntStream;

import com.google.common.base.Preconditions;
import com.koloboke.collect.map.IntLongMap;
import com.koloboke.collect.map.hash.HashIntLongMaps;
import com.typesafe.config.Config;
import org.apache.hadoop.fs.Path;
import org.apache.spark.api.java.JavaRDD;
@@ -61,6 +59,8 @@
import org.dmg.pmml.mining.Segmentation;
import org.dmg.pmml.tree.Node;
import org.dmg.pmml.tree.TreeModel;
import org.eclipse.collections.api.map.primitive.IntLongMap;
import org.eclipse.collections.impl.map.mutable.primitive.IntLongHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.collection.JavaConversions;
@@ -162,8 +162,8 @@ public PMML buildModel(JavaSparkContext sparkContext,
seed);
}

List<Map<Integer,Long>> treeNodeIDCounts = treeNodeExampleCounts(trainPointData, model);
Map<Integer,Long> predictorIndexCounts = predictorExampleCounts(trainPointData, model);
List<IntLongHashMap> treeNodeIDCounts = treeNodeExampleCounts(trainPointData, model);
IntLongHashMap predictorIndexCounts = predictorExampleCounts(trainPointData, model);

return rdfModelToPMML(model,
categoricalValueEncodings,
@@ -264,31 +264,30 @@ private JavaRDD<LabeledPoint> parseToLabeledPointRDD(
* per tree in the model
* @see #predictorExampleCounts(JavaRDD,RandomForestModel)
*/
private static List<Map<Integer,Long>> treeNodeExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData,
RandomForestModel model) {
private static List<IntLongHashMap> treeNodeExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData,
RandomForestModel model) {
return trainPointData.mapPartitions(data -> {
DecisionTreeModel[] trees = model.trees();
List<IntLongMap> treeNodeIDCounts = IntStream.range(0, trees.length).
mapToObj(i -> HashIntLongMaps.newMutableMap()).collect(Collectors.toList());
List<IntLongHashMap> treeNodeIDCounts = IntStream.range(0, trees.length).
mapToObj(i -> new IntLongHashMap()).collect(Collectors.toList());
data.forEachRemaining(datum -> {
double[] featureVector = datum.features().toArray();
for (int i = 0; i < trees.length; i++) {
DecisionTreeModel tree = trees[i];
IntLongMap nodeIDCount = treeNodeIDCounts.get(i);
IntLongHashMap nodeIDCount = treeNodeIDCounts.get(i);
org.apache.spark.mllib.tree.model.Node node = tree.topNode();
// This logic cloned from Node.predict:
while (!node.isLeaf()) {
// Count node ID
nodeIDCount.addValue(node.id(), 1);
nodeIDCount.addToValue(node.id(), 1);
Split split = node.split().get();
int featureIndex = split.feature();
node = nextNode(featureVector, node, split, featureIndex);
}
nodeIDCount.addValue(node.id(), 1);
nodeIDCount.addToValue(node.id(), 1);
}
});
return Collections.<List<Map<Integer,Long>>>singleton(
treeNodeIDCounts.stream().map(HashMap::new).collect(Collectors.toList())).iterator();
return Collections.singleton(treeNodeIDCounts).iterator();
}
).reduce((a, b) -> {
Preconditions.checkArgument(a.size() == b.size());
@@ -307,10 +306,10 @@ private static List<Map<Integer,Long>> treeNodeExampleCounts(JavaRDD<? extends L
* features, since there are fewer predictors than features. That is, the index will
* match the one used in the {@link RandomForestModel}.
*/
private static Map<Integer,Long> predictorExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData,
RandomForestModel model) {
private static IntLongHashMap predictorExampleCounts(JavaRDD<? extends LabeledPoint> trainPointData,
RandomForestModel model) {
return trainPointData.mapPartitions(data -> {
IntLongMap featureIndexCount = HashIntLongMaps.newMutableMap();
IntLongHashMap featureIndexCount = new IntLongHashMap();
data.forEachRemaining(datum -> {
double[] featureVector = datum.features().toArray();
for (DecisionTreeModel tree : model.trees()) {
@@ -320,14 +319,12 @@ private static Map<Integer,Long> predictorExampleCounts(JavaRDD<? extends Labele
Split split = node.split().get();
int featureIndex = split.feature();
// Count feature
featureIndexCount.addValue(featureIndex, 1);
featureIndexCount.addToValue(featureIndex, 1);
node = nextNode(featureVector, node, split, featureIndex);
}
}
});
// Clone to avoid problem with Kryo serializing Koloboke
return Collections.<Map<Integer,Long>>singleton(
new HashMap<>(featureIndexCount)).iterator();
return Collections.singleton(featureIndexCount).iterator();
}).reduce(RDFUpdate::merge);
}

@@ -352,26 +349,21 @@ private static org.apache.spark.mllib.tree.model.Node nextNode(
}
}

private static <T> Map<T,Long> merge(Map<T,Long> a, Map<T,Long> b) {
private static IntLongHashMap merge(IntLongHashMap a, IntLongHashMap b) {
if (b.size() > a.size()) {
return merge(b, a);
}
b.forEach((key, value) -> a.merge(key, value, (x, y) -> x + y));
b.forEachKeyValue(a::addToValue);
return a;
}

private static <T> long get(Map<T,Long> map, T key) {
Long count = map.get(key);
return count == null ? 0L : count;
}

private PMML rdfModelToPMML(RandomForestModel rfModel,
CategoricalValueEncodings categoricalValueEncodings,
int maxDepth,
int maxSplitCandidates,
String impurity,
List<? extends Map<Integer,Long>> nodeIDCounts,
Map<Integer,Long> predictorIndexCounts) {
List<? extends IntLongMap> nodeIDCounts,
IntLongMap predictorIndexCounts) {

boolean classificationTask = rfModel.algo().equals(Algo.Classification());
Preconditions.checkState(classificationTask == inputSchema.isClassification());
@@ -422,7 +414,7 @@ private PMML rdfModelToPMML(RandomForestModel rfModel,

private TreeModel toTreeModel(DecisionTreeModel dtModel,
CategoricalValueEncodings categoricalValueEncodings,
Map<Integer,Long> nodeIDCounts) {
IntLongMap nodeIDCounts) {

boolean classificationTask = dtModel.algo().equals(Algo.Classification());
Preconditions.checkState(classificationTask == inputSchema.isClassification());
@@ -448,7 +440,7 @@ private TreeModel toTreeModel(DecisionTreeModel dtModel,
modelNode.setPredicate(predicate);

org.apache.spark.mllib.tree.model.Node treeNode = treeNodePredicate.getFirst();
long nodeCount = get(nodeIDCounts, treeNode.id());
long nodeCount = nodeIDCounts.get(treeNode.id());
modelNode.setRecordCount((double) nodeCount);

if (treeNode.isLeaf()) {
@@ -492,7 +484,7 @@ private TreeModel toTreeModel(DecisionTreeModel dtModel,
org.apache.spark.mllib.tree.model.Node rightTreeNode = treeNode.rightNode().get();
org.apache.spark.mllib.tree.model.Node leftTreeNode = treeNode.leftNode().get();

boolean defaultRight = get(nodeIDCounts, rightTreeNode.id()) > get(nodeIDCounts, leftTreeNode.id());
boolean defaultRight = nodeIDCounts.get(rightTreeNode.id()) > nodeIDCounts.get(leftTreeNode.id());
modelNode.setDefaultChild(defaultRight ? positiveModelNode.getId() : negativeModelNode.getId());

// Right node is "positive", so carries the predicate. It must evaluate first
@@ -528,7 +520,7 @@ private Predicate buildPredicate(Split split,
// So the predicate will evaluate "not in" this set
// More ugly casting
@SuppressWarnings("unchecked")
List<Double> javaCategories = (List<Double>) (List<?>)
Collection<Double> javaCategories = (Collection<Double>) (Collection<?>)
JavaConversions.seqAsJavaList(split.categories());
Set<Integer> negativeEncodings = javaCategories.stream().map(Double::intValue).collect(Collectors.toSet());

@@ -548,11 +540,11 @@
}
}

private double[] countsToImportances(Map<Integer,Long> predictorIndexCounts) {
private double[] countsToImportances(IntLongMap predictorIndexCounts) {
double[] importances = new double[inputSchema.getNumPredictors()];
long total = predictorIndexCounts.values().stream().mapToLong(l -> l).sum();
long total = predictorIndexCounts.sum();
Preconditions.checkArgument(total > 0);
predictorIndexCounts.forEach((k, count) -> importances[k] = (double) count / total);
predictorIndexCounts.forEachKeyValue((k, count) -> importances[k] = (double) count / total);
return importances;
}

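The RDFUpdate changes lean on a few IntLongHashMap behaviors; the illustrative sketch below (not from the repository) exercises them: addToValue accumulates counts, get returns 0 for absent keys (so the old null-checking get helper could be dropped), forEachKeyValue drives the merge, and sum() totals all values as used in countsToImportances.

import org.eclipse.collections.impl.map.mutable.primitive.IntLongHashMap;

public class CountMapSketch {

  // Same shape as RDFUpdate.merge: fold the smaller map into the larger one
  static IntLongHashMap merge(IntLongHashMap a, IntLongHashMap b) {
    if (b.size() > a.size()) {
      return merge(b, a);
    }
    b.forEachKeyValue(a::addToValue); // addToValue(key, delta) creates or increments the entry
    return a;
  }

  public static void main(String[] args) {
    IntLongHashMap counts = new IntLongHashMap();
    counts.addToValue(7, 1); // creates key 7 with value 1
    counts.addToValue(7, 1); // increments to 2

    System.out.println(counts.get(7)); // 2
    System.out.println(counts.get(8)); // 0 -- missing keys read as zero

    IntLongHashMap other = new IntLongHashMap();
    other.addToValue(7, 3);
    other.addToValue(9, 5);

    IntLongHashMap merged = merge(counts, other);
    System.out.println(merged.get(7)); // 5
    System.out.println(merged.sum());  // 10, the total of all counts
  }
}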
@@ -29,14 +29,13 @@
import javax.ws.rs.QueryParam;
import javax.ws.rs.core.MediaType;

import com.koloboke.function.ObjDoubleToDoubleFunction;

import com.cloudera.oryx.api.serving.OryxServingException;
import com.cloudera.oryx.app.als.Rescorer;
import com.cloudera.oryx.app.als.RescorerProvider;
import com.cloudera.oryx.app.serving.IDValue;
import com.cloudera.oryx.app.serving.als.model.ALSServingModel;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction;

/**
* <p>Responds to a GET request to
@@ -92,7 +91,7 @@ public List<IDValue> get(
}
}

ObjDoubleToDoubleFunction<String> rescoreFn = null;
ToDoubleObjDoubleBiFunction<String> rescoreFn = null;
RescorerProvider rescorerProvider = getALSServingModel().getRescorerProvider();
if (rescorerProvider != null) {
Rescorer rescorer = rescorerProvider.getRecommendRescorer(Collections.singletonList(userID),
@@ -31,14 +31,13 @@
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.PathSegment;

import com.koloboke.function.ObjDoubleToDoubleFunction;

import com.cloudera.oryx.api.serving.OryxServingException;
import com.cloudera.oryx.app.als.Rescorer;
import com.cloudera.oryx.app.als.RescorerProvider;
import com.cloudera.oryx.app.serving.IDValue;
import com.cloudera.oryx.app.serving.als.model.ALSServingModel;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction;

/**
* <p>Responds to a GET request to
@@ -79,7 +78,7 @@ public List<IDValue> get(

Collection<String> knownItemsSet = new HashSet<>(knownItems);
Predicate<String> allowedFn = v -> !knownItemsSet.contains(v);
ObjDoubleToDoubleFunction<String> rescoreFn = null;
ToDoubleObjDoubleBiFunction<String> rescoreFn = null;
RescorerProvider rescorerProvider = getALSServingModel().getRescorerProvider();
if (rescorerProvider != null) {
Rescorer rescorer = rescorerProvider.getRecommendToAnonymousRescorer(knownItems,
@@ -31,14 +31,13 @@
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.PathSegment;

import com.koloboke.function.ObjDoubleToDoubleFunction;

import com.cloudera.oryx.api.serving.OryxServingException;
import com.cloudera.oryx.app.als.Rescorer;
import com.cloudera.oryx.app.als.RescorerProvider;
import com.cloudera.oryx.app.serving.IDValue;
import com.cloudera.oryx.app.serving.als.model.ALSServingModel;
import com.cloudera.oryx.common.collection.Pair;
import com.cloudera.oryx.common.lang.ToDoubleObjDoubleBiFunction;

/**
* <p>Responds to a GET request to
@@ -90,7 +89,7 @@ public List<IDValue> get(
allowedFn = v -> !userKnownItems.contains(v);
}

ObjDoubleToDoubleFunction<String> rescoreFn = null;
ToDoubleObjDoubleBiFunction<String> rescoreFn = null;
RescorerProvider rescorerProvider = getALSServingModel().getRescorerProvider();
if (rescorerProvider != null) {
Rescorer rescorer = rescorerProvider.getRecommendRescorer(userIDs, rescorerParams);
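The three serving resources above swap Koloboke's ObjDoubleToDoubleFunction for a project-local ToDoubleObjDoubleBiFunction in com.cloudera.oryx.common.lang. That interface is not shown in this diff; the sketch below assumes it has roughly the following shape (the interface body and method name are guesses) and shows how a rescoring function of that type would be used.

// Assumed shape of ToDoubleObjDoubleBiFunction; the real interface lives in oryx-common and may differ.
@FunctionalInterface
interface ToDoubleObjDoubleBiFunction<T> {
  double applyAsDouble(T t, double value);
}

public class RescoreSketch {
  public static void main(String[] args) {
    // Take an item ID and its raw score, return an adjusted score -- the role rescoreFn plays above
    ToDoubleObjDoubleBiFunction<String> rescoreFn =
        (itemID, score) -> itemID.startsWith("promo-") ? score * 1.5 : score;

    System.out.println(rescoreFn.applyAsDouble("promo-1", 2.0)); // 3.0
    System.out.println(rescoreFn.applyAsDouble("item-9", 2.0));  // 2.0
  }
}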