diff --git a/algo/src/main/java/org/neo4j/graphalgo/similarity/Similarities.java b/algo/src/main/java/org/neo4j/graphalgo/similarity/Similarities.java
index fa97b9925..1ac68a60e 100644
--- a/algo/src/main/java/org/neo4j/graphalgo/similarity/Similarities.java
+++ b/algo/src/main/java/org/neo4j/graphalgo/similarity/Similarities.java
@@ -45,10 +45,16 @@ public double jaccardSimilarity(@Name("vector1") List<Number> vector1, @Name("ve
         if (vector1 == null || vector2 == null) return 0;
 
         HashSet<Number> intersectionSet = new HashSet<>(vector1);
+
+        // Sum the distinct-element counts of both vectors (duplicates ignored) before retainAll
+        // shrinks intersectionSet; the set built from vector2 is reused so retainAll does hash lookups.
+        HashSet<Number> distinctVector2 = new HashSet<>(vector2);
+        long distinctSizeSum = intersectionSet.size() + distinctVector2.size();
+
-        intersectionSet.retainAll(vector2);
+        intersectionSet.retainAll(distinctVector2);
         int intersection = intersectionSet.size();
 
-        long denominator = vector1.size() + vector2.size() - intersection;
+        long denominator = distinctSizeSum - intersection;
         return denominator == 0 ? 0 : (double) intersection / denominator;
     }
 
@@ -172,10 +178,17 @@ public double overlapSimilarity(@Name("vector1") List<Number> vector1, @Name("ve
         if (vector1 == null || vector2 == null) return 0;
 
         HashSet<Number> intersectionSet = new HashSet<>(vector1);
+
+        // Measure both vectors by their distinct elements (duplicates ignored) before retainAll
+        // shrinks intersectionSet; the set built from vector2 is reused so retainAll does hash lookups.
+        HashSet<Number> distinctVector2 = new HashSet<>(vector2);
+        long size1 = intersectionSet.size();
+        long size2 = distinctVector2.size();
+
-        intersectionSet.retainAll(vector2);
+        intersectionSet.retainAll(distinctVector2);
         int intersection = intersectionSet.size();
 
-        long denominator = Math.min(vector1.size(), vector2.size());
+        long denominator = Math.min(size1, size2);
         return denominator == 0 ? 0 : (double) intersection / denominator;
     }
 
diff --git a/algo/src/test/java/org/neo4j/graphalgo/similarity/SimilaritiesTest.java b/algo/src/test/java/org/neo4j/graphalgo/similarity/SimilaritiesTest.java
new file mode 100644
index 000000000..1e70469bb
--- /dev/null
+++ b/algo/src/test/java/org/neo4j/graphalgo/similarity/SimilaritiesTest.java
@@ -0,0 +1,113 @@
+/*
+ * Copyright (c) 2017 "Neo4j, Inc."
+ *
+ * This file is part of Neo4j Graph Algorithms.
+ *
+ * Neo4j Graph Algorithms is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+package org.neo4j.graphalgo.similarity;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Checks that the jaccard, cosine, pearson, euclidean and overlap similarity functions each
+ * score a vector against itself as 1.0, including vectors that contain duplicate values.
+ */
+@RunWith(Parameterized.class)
+public class SimilaritiesTest {
+
+    private final List<Number> input;
+
+    @Parameterized.Parameters(name = "{0}")
+    public static Collection<List<Number>> data() {
+        return Arrays.asList(
+                Arrays.asList(1, 2, 3),
+                Arrays.asList(1, 2, 3, 3),
+                Arrays.asList(104, 101, 108, 108, 111)
+        );
+    }
+
+    public SimilaritiesTest(List<Number> input) {
+        this.input = input;
+    }
+
+    @Test
+    public void jaccardIdenticalInput() {
+        // given identical input
+
+        // when
+        Similarities s = new Similarities();
+        double result = s.jaccardSimilarity(input, input);
+
+        // then
+        assertEquals(1.0, result, 0.01);
+    }
+
+    @Test
+    public void cosineIdenticalInput() {
+        // given identical input
+
+        // when
+        Similarities s = new Similarities();
+        double result = s.cosineSimilarity(input, input);
+
+        // then
+        assertEquals(1.0, result, 0.01);
+    }
+
+    @Test
+    public void pearsonIdenticalInput() {
+        // given identical input
+
+        // when
+        Similarities s = new Similarities();
+        double result = s.pearsonSimilarity(input, input, Collections.emptyMap());
+
+        // then
+        assertEquals(1.0, result, 0.01);
+    }
+
+    @Test
+    public void euclideanIdenticalInput() {
+        // given identical input
+
+        // when
+        Similarities s = new Similarities();
+        double result = s.euclideanSimilarity(input, input);
+
+        // then
+        assertEquals(1.0, result, 0.01);
+    }
+
+    @Test
+    public void overlapIdenticalInput() {
+        // given identical input
+
+        // when
+        Similarities s = new Similarities();
+        double result = s.overlapSimilarity(input, input);
+
+        // then
+        assertEquals(1.0, result, 0.01);
+    }
+}