Skip to content

Commit

Permalink
Fix heuristic cluster refinement (#21)
Browse files Browse the repository at this point in the history
* Add tests for RefineCluster and fix heuristic refinement

* rename variable

* Add test for Gaussian sum

* rename method

* rename method
  • Loading branch information
SvenLehmann authored Sep 30, 2019
1 parent ea7e079 commit 741df81
Show file tree
Hide file tree
Showing 4 changed files with 338 additions and 121 deletions.
2 changes: 1 addition & 1 deletion build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ subprojects {
"testImplementation"("org.junit.jupiter:junit-jupiter-api:5.3.0")
"testImplementation"("org.junit.jupiter:junit-jupiter-params:5.3.0")
"testRuntimeOnly"("org.junit.jupiter:junit-jupiter-engine:5.3.0")
"testImplementation"(group = "org.assertj", name = "assertj-core", version = "3.11.1")
"testImplementation"(group = "org.assertj", name = "assertj-core", version = "3.13.0")

"compileOnly"("org.projectlombok:lombok:1.18.6")
"annotationProcessor"("org.projectlombok:lombok:1.18.6")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ private Optional<ClassificationResult> evaluateRule(final @NonNull Rule<? super
* <p>A positive rule returns a positive score [0; 1] and results in a DUPLICATE.</p>
*/
private ClassificationResult mapScoreToResult(final @NonNull Rule<? super T> rule, final double score) {
if (RuleBasedClassifier.didNotApply(score)) {
if (didNotApply(score)) {
return UNKNOWN;
}
if (score <= -0.0d) {
Expand Down
185 changes: 100 additions & 85 deletions common/src/main/java/com/bakdata/dedupe/clustering/RefineCluster.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import java.util.stream.StreamSupport;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Value;
import lombok.experimental.FieldDefaults;
Expand Down Expand Up @@ -81,7 +82,7 @@ public class RefineCluster<C extends Comparable<C>, T> {
* (max - 1) / 2}.
*/
@Builder.Default
int maxSmallClusterSize = 10;
final int maxSmallClusterSize = 10;
/**
* The classifier used to score the edges. Please note that binary classifiers (confidence always 1) can be used but
* will not unleash the full potential.
Expand All @@ -107,10 +108,6 @@ private static double getWeight(final ClassificationResult classificationResult)
}
}

private static int getNumEdges(final int n) {
return n * (n - 1) / 2;
}

private static double scoreClustering(final byte[] partitions, final double[][] weightMatrix) {
final int n = partitions.length;
final int[] partitionSizes = new int[n];
Expand All @@ -121,11 +118,12 @@ private static double scoreClustering(final byte[] partitions, final double[][]
double score = 0;
for (int rowIndex = 0; rowIndex < n; rowIndex++) {
for (int colIndex = rowIndex + 1; colIndex < n; colIndex++) {
final double weightForEdge = weightMatrix[rowIndex][colIndex];
if (partitions[rowIndex] == partitions[colIndex]) {
score += weightMatrix[rowIndex][colIndex] / partitionSizes[partitions[rowIndex]];
score += weightForEdge / partitionSizes[partitions[rowIndex]];
} else {
score -= weightMatrix[rowIndex][colIndex] / (n - partitionSizes[partitions[rowIndex]]) +
weightMatrix[rowIndex][colIndex] / (n - partitionSizes[partitions[colIndex]]);
score -= weightForEdge / (n - partitionSizes[partitions[rowIndex]]) +
weightForEdge / (n - partitionSizes[partitions[colIndex]]);
}
}
}
Expand All @@ -134,12 +132,12 @@ private static double scoreClustering(final byte[] partitions, final double[][]

static List<WeightedEdge> getRandomEdges(final int potentialNumEdges, final int desiredNumEdges) {
return RANDOM.ints(0, potentialNumEdges)
.distinct()
.mapToObj(RefineCluster::createGaussPair)
.filter(RefineCluster::isNotSelfPair)
.map(p -> WeightedEdge.of(p.getLeft(), p.getRight(), Double.NaN))
.limit(desiredNumEdges)
.collect(Collectors.toList());
.distinct()
.mapToObj(RefineCluster::createGaussPair)
.filter(RefineCluster::isNotSelfPair)
.map(p -> WeightedEdge.of(p.getLeft(), p.getRight(), Double.NaN))
.limit(desiredNumEdges)
.collect(Collectors.toList());
}

private static <T> boolean isNotSelfPair(final Pair<T, T> pair) {
Expand All @@ -149,10 +147,14 @@ private static <T> boolean isNotSelfPair(final Pair<T, T> pair) {
static Pair<Integer, Integer> createGaussPair(final int i) {
// reverse of Gaussian
final int leftIndex = (int) (Math.sqrt(2 * i + 0.25) - 0.5);
final int rightIndex = i - getNumEdges(leftIndex + 1);
final int rightIndex = i - triangularNumber(leftIndex);
return Pair.of(leftIndex, rightIndex);
}

/**
 * Returns the {@code n}-th triangular number, {@code n * (n + 1) / 2}, i.e. the sum {@code 0 + 1 + ... + n}.
 *
 * <p>The product is computed in {@code long} so that the intermediate {@code n * (n + 1)} cannot wrap around
 * when the final result still fits into an {@code int}.</p>
 *
 * @param n the index of the triangular number
 * @return {@code n * (n + 1) / 2}
 * @throws ArithmeticException if the result does not fit into an {@code int}
 */
static int triangularNumber(final int n) {
    // n * (n + 1) is always even, so the division is exact.
    return Math.toIntExact((long) n * (n + 1L) / 2);
}

private List<ClassifiedCandidate<T>> getRelevantClassifications(final Cluster<C, ? super T> cluster,
final @NonNull Map<T, List<ClassifiedCandidate<T>>> relevantClassificationIndex) {
return cluster.getElements().stream()
Expand All @@ -179,9 +181,10 @@ private Map<T, List<ClassifiedCandidate<T>>> getRelevantClassificationIndex(
private byte[] refineBigCluster(final @NonNull Cluster<C, T> cluster,
final @NonNull Collection<ClassifiedCandidate<T>> knownClassifications) {
final List<WeightedEdge> duplicates = this.toWeightedEdges(knownClassifications, cluster);
final int desiredNumEdges = getNumEdges(this.maxSmallClusterSize);
final int desiredNumEdges = triangularNumber(this.maxSmallClusterSize);

return this.greedyCluster(cluster, this.getWeightedEdges(cluster, duplicates, desiredNumEdges));
final GreedyClustering<C, T> greedyClustering = new GreedyClustering<>();
return greedyClustering.greedyCluster(cluster, this.getWeightedEdges(cluster, duplicates, desiredNumEdges));
}

/**
Expand Down Expand Up @@ -221,44 +224,48 @@ private List<WeightedEdge> toWeightedEdges(final Collection<ClassifiedCandidate<
IntStream.range(0, cluster.size()).boxed().collect(Collectors.toMap(cluster::get, i -> i));

return knownClassifications.stream()
.map(knownClassification -> WeightedEdge.of(clusterIndex.get(
knownClassification.getCandidate().getRecord1()),
.map(knownClassification -> WeightedEdge.of(
clusterIndex.get(knownClassification.getCandidate().getRecord1()),
clusterIndex.get(knownClassification.getCandidate().getRecord2()),
getWeight(knownClassification.getClassificationResult())))
.collect(Collectors.toList());
}

private Stream<Cluster<C, T>> refineCluster(final Cluster<C, T> cluster,
final @NonNull List<ClassifiedCandidate<T>> knownClassifications) {
final @NonNull Collection<ClassifiedCandidate<T>> knownClassifications) {
if (cluster.size() <= 2) {
return Stream.of(cluster);
}

final byte[] bestClustering;
final byte[] bestClustering = this.getBestClustering(cluster, knownClassifications);
return this.getSubClusters(bestClustering, cluster);
}

private byte[] getBestClustering(final Cluster<C, T> cluster,
final @NonNull Collection<ClassifiedCandidate<T>> knownClassifications) {
if (cluster.size() > this.maxSmallClusterSize) {
// large cluster with high probability of error
bestClustering = this.refineBigCluster(cluster, knownClassifications);

} else {
bestClustering = this.refineSmallCluster(cluster, knownClassifications);
return this.refineBigCluster(cluster, knownClassifications);
}

return this.getSubClusters(bestClustering, cluster);
return this.refineSmallCluster(cluster, knownClassifications);
}

private @NonNull double[][] getKnownWeightMatrix(final Cluster<C, T> cluster,
private @NonNull double[][] getKnownWeightMatrix(final Cluster<C, ? extends T> cluster,
final @NonNull Iterable<ClassifiedCandidate<T>> knownClassifications) {
final int n = cluster.size();
final double[][] weightMatrix = new double[n][n];
for (final double[] row : weightMatrix) {
Arrays.fill(row, Double.NaN);
}

final Map<T, Integer> clusterIndex =
IntStream.range(0, n).boxed().collect(Collectors.toMap(cluster::get, i -> i));
final Map<T, Integer> clusterIndex = IntStream.range(0, n)
.boxed()
.collect(Collectors.toMap(cluster::get, i -> i));

for (final ClassifiedCandidate<T> knownClassification : knownClassifications) {
final Integer firstIndex = clusterIndex.get(knownClassification.getCandidate().getRecord1());
final Integer secondIndex = clusterIndex.get(knownClassification.getCandidate().getRecord2());
final int firstIndex = clusterIndex.get(knownClassification.getCandidate().getRecord1());
final int secondIndex = clusterIndex.get(knownClassification.getCandidate().getRecord2());
weightMatrix[Math.min(firstIndex, secondIndex)][Math.max(firstIndex, secondIndex)] =
getWeight(knownClassification.getClassificationResult());
}
Expand All @@ -275,81 +282,54 @@ private Stream<Cluster<C, T>> getSubClusters(final byte[] bestClustering,
.map(records -> new Cluster<>(this.clusterIdGenerator.apply(records), records));
}

private byte[] greedyCluster(final Cluster<C, T> cluster, final @NonNull Collection<? extends WeightedEdge> edges) {

final Collection<WeightedEdge> queue = new PriorityQueue<>(Comparator.comparing(WeightedEdge::getWeight));
queue.addAll(edges);

final double[][] weightMatrix = new double[cluster.size()][cluster.size()];
for (final WeightedEdge edge : edges) {
weightMatrix[edge.left][edge.right] = edge.getWeight();
}

// start with each publication in its own cluster
byte[] clustering = Bytes.toArray(IntStream.range(0, cluster.size()).boxed().collect(Collectors.toList()));
double score = 0;
for (final WeightedEdge edge : queue) {
final byte[] newClustering = clustering.clone();
final byte newClusterId = newClustering[edge.left];
final byte oldClusterId = newClustering[edge.right];
for (int i = 0; i < newClustering.length; i++) {
if (newClustering[i] == oldClusterId) {
newClustering[i] = newClusterId;
}
}
final double newScore = scoreClustering(newClustering, weightMatrix);
if (newScore > score) {
score = newScore;
clustering = newClustering;
}
}
return clustering;
}

private List<WeightedEdge> addRandomEdges(final @NonNull List<? extends WeightedEdge> edges,
final int desiredNumEdges) {
// add random edges with distance 2..n of known edges (e.g., neighbors of known edges).
List<WeightedEdge> lastAddedEdges;
final Set<WeightedEdge> weightedEdges = new LinkedHashSet<>(edges);
for (int distance = 2; distance < this.maxSmallClusterSize && weightedEdges.size() < desiredNumEdges;
distance++) {
lastAddedEdges = edges.stream()
.flatMap(e1 -> edges.stream().filter(e1::overlaps).map(e1::getTriangleEdge))
.filter(e -> !weightedEdges.contains(e))
// add random edges with distance 2..n of known edges (e.g., neighbors of known edges).
final List<WeightedEdge> lastAddedEdges = edges.stream()
.flatMap(edge -> edges.stream().filter(edge::overlaps).map(edge::getTriangleEdge))
.filter(edge -> !weightedEdges.contains(edge))
.limit((long) desiredNumEdges - edges.size())
.collect(Collectors.toList());
weightedEdges.addAll(lastAddedEdges);
Collections.shuffle(lastAddedEdges);
weightedEdges.addAll(lastAddedEdges);
}
if (weightedEdges.size() < desiredNumEdges) {
throw new IllegalStateException("We have a connected components, so we should get a fully connected graph");
throw new IllegalStateException("We have a connected component, so we should get a fully connected graph");
}
return new ArrayList<>(weightedEdges);
}

private List<WeightedEdge> getWeightedEdges(final @NonNull Cluster<C, ? extends T> cluster,
final List<? extends WeightedEdge> duplicates,
final int desiredNumEdges) {
final List<WeightedEdge> weightedEdges;
final List<WeightedEdge> edges = this.getEdges(cluster, duplicates, desiredNumEdges);

return edges.stream()
.map(edge -> this.calculateWeightIfNeeded(cluster, edge))
.collect(Collectors.toList());
}

private List<WeightedEdge> getEdges(final @NonNull Cluster<C, ? extends T> cluster,
final List<? extends WeightedEdge> duplicates, final int desiredNumEdges) {
if (duplicates.isEmpty()) {
final int n = cluster.size();
weightedEdges = getRandomEdges(getNumEdges(n), desiredNumEdges);
} else {
Collections.shuffle(duplicates);
weightedEdges = this.addRandomEdges(duplicates, desiredNumEdges);
return getRandomEdges(triangularNumber(n), desiredNumEdges);
}

return weightedEdges.stream()
.map(weightedEdge -> calculateWeightIfNeeded(cluster, weightedEdge))
.collect(Collectors.toList());
Collections.shuffle(duplicates);
return this.addRandomEdges(duplicates, desiredNumEdges);
}

private WeightedEdge calculateWeightIfNeeded(@NonNull Cluster<C, ? extends T> cluster, WeightedEdge weightedEdge) {
double weight = weightedEdge.getWeight();
private WeightedEdge calculateWeightIfNeeded(final @NonNull Cluster<C, ? extends T> cluster,
final WeightedEdge weightedEdge) {
final double weight = weightedEdge.getWeight();
if (Double.isNaN(weight)) {
// calculate weight for dummy entry
T left = cluster.get(weightedEdge.getLeft());
T right = cluster.get(weightedEdge.getRight());
final T left = cluster.get(weightedEdge.getLeft());
final T right = cluster.get(weightedEdge.getRight());
return weightedEdge.withWeight(getWeight(this.classifier.classify(new OnlineCandidate<>(left, right))));
}
return weightedEdge;
Expand All @@ -361,7 +341,7 @@ private static final class ClusteringGenerator implements Iterator<byte[]> {
final @NonNull byte[] clustering;
boolean hasNext = true;

ClusteringGenerator(final byte n) {
private ClusteringGenerator(final byte n) {
this.n = n;
this.clustering = new byte[n];
}
Expand Down Expand Up @@ -402,8 +382,10 @@ private boolean incrementWouldResultInSkippedInteger(final byte i) {
}

@Value
private static class WeightedEdge {
static class WeightedEdge {
@Getter
int left;
@Getter
int right;
@Wither
double weight;
Expand All @@ -412,7 +394,7 @@ static WeightedEdge of(final int leftIndex, final int rightIndex, final double w
return new WeightedEdge(Math.min(leftIndex, rightIndex), Math.max(leftIndex, rightIndex), weight);
}

@NonNull WeightedEdge getTriangleEdge(final @NonNull WeightedEdge e) {
private @NonNull WeightedEdge getTriangleEdge(final @NonNull WeightedEdge e) {
if (this.left < e.left) {
return new WeightedEdge(this.left, e.left + e.right - this.right, Double.NaN);
} else if (this.left == e.left) {
Expand All @@ -421,11 +403,44 @@ static WeightedEdge of(final int leftIndex, final int rightIndex, final double w
return new WeightedEdge(e.left, this.left + this.getRight() - e.getRight(), Double.NaN);
}

boolean overlaps(final @NonNull WeightedEdge e) {
private boolean overlaps(final @NonNull WeightedEdge e) {
return e.getLeft() == this.getLeft() || e.getLeft() == this.getRight() || e.getRight() == this.getLeft()
|| e.getRight() == this
.getRight();
}
}

/**
 * Greedy clustering heuristic: starts with every element of a cluster in its own partition and tries to merge
 * the two partitions connected by each edge, keeping a merge only if it strictly improves the clustering score.
 */
static class GreedyClustering<C extends Comparable<C>, T> {

    /**
     * Computes a partitioning of the given cluster by greedily merging along the given edges.
     *
     * @param cluster the cluster to partition; only its size is used here
     * @param edges the weighted edges between element indices of the cluster
     * @return an array in which position {@code i} holds the partition id of the {@code i}-th cluster element
     */
    byte[] greedyCluster(final Cluster<C, T> cluster, final @NonNull Collection<? extends WeightedEdge> edges) {
        // Iterating a PriorityQueue does not visit elements in priority order, so the edges are sorted
        // explicitly to really process them by weight.
        // NOTE(review): ascending weight is what the original comparator declared; confirm that
        // lowest-weight-first is the intended greedy order.
        final List<WeightedEdge> orderedEdges = new ArrayList<>(edges);
        orderedEdges.sort(Comparator.comparingDouble(WeightedEdge::getWeight));

        final int n = cluster.size();
        // edges that are not given keep the default weight 0
        final double[][] weightMatrix = new double[n][n];
        for (final WeightedEdge edge : edges) {
            weightMatrix[edge.getLeft()][edge.getRight()] = edge.getWeight();
        }

        // start with each element in its own partition
        // NOTE(review): partition ids are stored as bytes, so this presumably expects at most 128 elements - confirm.
        byte[] clustering = Bytes.toArray(IntStream.range(0, n).boxed().collect(Collectors.toList()));
        double score = scoreClustering(clustering, weightMatrix);
        for (final WeightedEdge edge : orderedEdges) {
            final byte newClusterId = clustering[edge.getLeft()];
            final byte oldClusterId = clustering[edge.getRight()];
            if (newClusterId == oldClusterId) {
                // both ends are already in the same partition; merging would not change the clustering or score
                continue;
            }

            // tentatively merge the partition of the right element into the partition of the left element
            final byte[] newClustering = clustering.clone();
            for (int i = 0; i < newClustering.length; i++) {
                if (newClustering[i] == oldClusterId) {
                    newClustering[i] = newClusterId;
                }
            }

            // keep the merge only if it strictly improves the score
            final double newScore = scoreClustering(newClustering, weightMatrix);
            if (newScore > score) {
                score = newScore;
                clustering = newClustering;
            }
        }
        return clustering;
    }
}
}
Loading

0 comments on commit 741df81

Please sign in to comment.