Skip to content
This repository has been archived by the owner on Apr 22, 2020. It is now read-only.

Commit

Permalink
Degree cutoff skip values (#880)
Browse files Browse the repository at this point in the history
* consider skip values when checking degree cut off

* d has no links so it gets filtered out by the degree cut off

* d has no links so it gets filtered out by the degree cut off

* typo
  • Loading branch information
mneedham authored Apr 10, 2019
1 parent 2c41049 commit 5229765
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,26 @@ static int[] indexesFor(long[] inputIds, ProcedureConfiguration configuration, S
}
}


/**
 * Converts a raw row value into a list of numbers.
 *
 * <ul>
 *   <li>{@code null} yields an empty, immutable list.</li>
 *   <li>{@code long[]} and {@code double[]} are boxed element by element into a new list.</li>
 *   <li>{@code int[]} and {@code float[]} are widened to {@code long} / {@code double}
 *       respectively and then boxed, so integral inputs always surface as {@code Long}
 *       and floating point inputs as {@code Double}.</li>
 *   <li>Anything else is cast to {@code List<Number>} and returned as the same instance.</li>
 * </ul>
 *
 * @param rawValues the raw value; may be {@code null}
 * @return the values as a list, never {@code null}
 * @throws ClassCastException if {@code rawValues} is neither a supported primitive array nor a {@link List}
 */
static List<Number> extractValues(Object rawValues) {
    if (rawValues == null) {
        return Collections.emptyList();
    }

    if (rawValues instanceof long[]) {
        long[] values = (long[]) rawValues;
        List<Number> valueList = new ArrayList<>(values.length);
        for (long value : values) {
            valueList.add(value);
        }
        return valueList;
    }

    if (rawValues instanceof double[]) {
        double[] values = (double[]) rawValues;
        List<Number> valueList = new ArrayList<>(values.length);
        for (double value : values) {
            valueList.add(value);
        }
        return valueList;
    }

    if (rawValues instanceof int[]) {
        int[] values = (int[]) rawValues;
        List<Number> valueList = new ArrayList<>(values.length);
        for (int value : values) {
            // Widen first so the element type matches what a long[] would produce.
            valueList.add((long) value);
        }
        return valueList;
    }

    if (rawValues instanceof float[]) {
        float[] values = (float[]) rawValues;
        List<Number> valueList = new ArrayList<>(values.length);
        for (float value : values) {
            // Widen first so the element type matches what a double[] would produce.
            valueList.add((double) value);
        }
        return valueList;
    }

    // The element type cannot be verified at runtime; a non-List input fails on this cast.
    @SuppressWarnings("unchecked")
    List<Number> valueList = (List<Number>) rawValues;
    return valueList;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ CategoricalInput[] prepareCategories(List<Map<String, Object>> data, long degree
CategoricalInput[] ids = new CategoricalInput[data.size()];
int idx = 0;
for (Map<String, Object> row : data) {
List<Number> targetIds = extractValues(row.get("categories"));
List<Number> targetIds = SimilarityInput.extractValues(row.get("categories"));
int size = targetIds.size();
if (size > degreeCutoff) {
long[] targets = new long[size];
Expand All @@ -156,32 +156,14 @@ WeightedInput[] prepareWeights(Object rawData, ProcedureConfiguration configurat
return prepareSparseWeights(api, (String) rawData, skipValue, configuration);
} else {
List<Map<String, Object>> data = (List<Map<String, Object>>) rawData;
return preparseDenseWeights(data, getDegreeCutoff(configuration), skipValue);
return WeightedInput.prepareDenseWeights(data, getDegreeCutoff(configuration), skipValue);
}
}

/**
 * Looks up the {@code "skipValue"} entry of the given configuration, passing
 * {@link Double#NaN} as the fallback argument of the lookup.
 *
 * @param configuration the procedure configuration to read from
 * @return whatever the configuration lookup yields for {@code "skipValue"}
 */
Double readSkipValue(ProcedureConfiguration configuration) {
    final Double configuredSkipValue = configuration.get("skipValue", Double.NaN);
    return configuredSkipValue;
}

/**
 * Builds one dense {@code WeightedInput} per row whose raw number of weights is
 * strictly greater than {@code degreeCutoff}, and returns them sorted.
 *
 * <p>The cutoff is checked against the full length of the row's {@code "weights"}
 * list; entries equal to the skip value are still counted here.
 *
 * @param data         rows providing an {@code "item"} id and a {@code "weights"} value
 * @param degreeCutoff rows with at most this many weights are dropped
 * @param skipValue    forwarded to {@code WeightedInput.dense} when not {@code null}
 * @return the retained inputs, sorted by their natural ordering
 */
private WeightedInput[] preparseDenseWeights(List<Map<String, Object>> data, long degreeCutoff, Double skipValue) {
    WeightedInput[] kept = new WeightedInput[data.size()];
    int keptCount = 0;

    for (Map<String, Object> row : data) {
        final List<Number> rowWeights = extractValues(row.get("weights"));
        if (rowWeights.size() <= degreeCutoff) {
            continue;
        }

        final double[] weights = Weights.buildWeights(rowWeights);
        final Long item = (Long) row.get("item");
        if (skipValue == null) {
            kept[keptCount] = WeightedInput.dense(item, weights);
        } else {
            kept[keptCount] = WeightedInput.dense(item, weights, skipValue);
        }
        keptCount++;
    }

    // Trim the unused tail left behind by rows that were filtered out.
    if (keptCount < kept.length) {
        kept = Arrays.copyOf(kept, keptCount);
    }
    Arrays.sort(kept);
    return kept;
}

private WeightedInput[] prepareSparseWeights(GraphDatabaseAPI api, String query, Double skipValue, ProcedureConfiguration configuration) throws Exception {
Map<String, Object> params = configuration.getParams();
Long degreeCutoff = getDegreeCutoff(configuration);
Expand Down Expand Up @@ -230,28 +212,6 @@ private WeightedInput[] prepareSparseWeights(GraphDatabaseAPI api, String query,
return inputs;
}

/**
 * Converts a raw row value into a list of numbers.
 *
 * <p>{@code null} yields an empty list, {@code long[]} and {@code double[]} are
 * boxed element by element into a new list, and any other value is cast to
 * {@code List<Number>} and returned as the same instance.
 *
 * @param rawValues the raw value; may be {@code null}
 * @return the values as a list, never {@code null}
 */
private List<Number> extractValues(Object rawValues) {
    if (rawValues == null) {
        return Collections.emptyList();
    }

    if (rawValues instanceof long[]) {
        final long[] longs = (long[]) rawValues;
        final List<Number> boxed = new ArrayList<>();
        for (int i = 0; i < longs.length; i++) {
            boxed.add(longs[i]);
        }
        return boxed;
    }

    if (rawValues instanceof double[]) {
        final double[] doubles = (double[]) rawValues;
        final List<Number> boxed = new ArrayList<>();
        for (int i = 0; i < doubles.length; i++) {
            boxed.add(doubles[i]);
        }
        return boxed;
    }

    // Everything else is expected to already be a list of numbers.
    return (List<Number>) rawValues;
}

/**
 * Reads the {@code "topK"} setting from the given configuration, passing
 * {@code 0} as the fallback argument of the lookup.
 *
 * @param configuration the procedure configuration to read from
 * @return the value the configuration reports for {@code "topK"}
 */
int getTopK(ProcedureConfiguration configuration) {
    final int topK = configuration.getInt("topK", 0);
    return topK;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@

import org.neo4j.graphalgo.core.utils.Intersections;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

class WeightedInput implements Comparable<WeightedInput>, SimilarityInput {
private final long id;
private int itemCount;
Expand Down Expand Up @@ -62,6 +66,32 @@ public static WeightedInput dense(long id, double[] weights) {
return new WeightedInput(id, weights);
}

/**
 * Builds one dense {@code WeightedInput} per row that passes the degree cutoff
 * and returns them sorted.
 *
 * <p>When {@code skipValue} is non-null, only weights that
 * {@code Intersections.shouldSkip} does not reject count towards the cutoff;
 * otherwise the full length of the weight list is used. A row is kept only if
 * that count is strictly greater than {@code degreeCutoff}.
 *
 * @param data         rows providing an {@code "item"} id and a {@code "weights"} value
 * @param degreeCutoff rows whose counted weights do not exceed this are dropped
 * @param skipValue    value to ignore when counting, or {@code null} to count everything;
 *                     also forwarded to {@code dense} when non-null
 * @return the retained inputs, sorted by their natural ordering
 */
static WeightedInput[] prepareDenseWeights(List<Map<String, Object>> data, long degreeCutoff, Double skipValue) {
    final boolean hasSkipValue = skipValue != null;
    final boolean nanIsSkipped = hasSkipValue && Double.isNaN(skipValue);

    WeightedInput[] kept = new WeightedInput[data.size()];
    int keptCount = 0;

    for (Map<String, Object> row : data) {
        final List<Number> rowWeights = SimilarityInput.extractValues(row.get("weights"));

        final long countedWeights;
        if (hasSkipValue) {
            countedWeights = skipSize(skipValue, nanIsSkipped, rowWeights);
        } else {
            countedWeights = rowWeights.size();
        }

        if (countedWeights <= degreeCutoff) {
            continue;
        }

        final double[] weights = Weights.buildWeights(rowWeights);
        final Long item = (Long) row.get("item");
        if (hasSkipValue) {
            kept[keptCount] = dense(item, weights, skipValue);
        } else {
            kept[keptCount] = dense(item, weights);
        }
        keptCount++;
    }

    // Trim the unused tail left behind by rows that were filtered out.
    if (keptCount < kept.length) {
        kept = Arrays.copyOf(kept, keptCount);
    }
    Arrays.sort(kept);
    return kept;
}

/**
 * Counts the weights that {@code Intersections.shouldSkip} does not reject
 * for the given skip value.
 *
 * @param skipValue  the value to ignore
 * @param skipNan    forwarded unchanged to {@code Intersections.shouldSkip}
 * @param weightList the weights to inspect
 * @return the number of weights that are not skipped
 */
private static long skipSize(Double skipValue, boolean skipNan, List<Number> weightList) {
    long counted = 0L;
    for (Number weight : weightList) {
        if (!Intersections.shouldSkip(weight.doubleValue(), skipValue, skipNan)) {
            counted++;
        }
    }
    return counted;
}

/**
 * Orders inputs by ascending {@code id}.
 *
 * @param o the input to compare against
 * @return a negative, zero or positive value as this id is less than, equal to
 *         or greater than the other id
 */
public int compareTo(WeightedInput o) {
    final long otherId = o.id;
    return Long.compare(this.id, otherId);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,50 @@
package org.neo4j.graphalgo.similarity;

import org.junit.Test;
import org.neo4j.helpers.collection.MapUtil;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static junit.framework.TestCase.assertEquals;
import static junit.framework.TestCase.assertNull;

public class WeightedInputTest {
@Test
public void degreeCutoffBasedOnSkipValue() {
    // With NaN as the skip value and a cutoff of 2, the row containing a NaN
    // has only two countable weights and is expected to be filtered out.
    final Map<String, Object> rowWithoutNan = MapUtil.map("item", 1L, "weights", Arrays.asList(2.0, 3.0, 4.0));
    final Map<String, Object> rowWithNan = MapUtil.map("item", 2L, "weights", Arrays.asList(2.0, 3.0, Double.NaN));

    final List<Map<String, Object>> rows = new ArrayList<>();
    rows.add(rowWithoutNan);
    rows.add(rowWithNan);

    final WeightedInput[] retained = WeightedInput.prepareDenseWeights(rows, 2L, Double.NaN);

    assertEquals(1, retained.length);
}

@Test
public void degreeCutoffWithoutSkipValue() {
    // Without a skip value every weight counts, so both three-element rows
    // are expected to pass a cutoff of 2.
    final Map<String, Object> rowWithoutNan = MapUtil.map("item", 1L, "weights", Arrays.asList(2.0, 3.0, 4.0));
    final Map<String, Object> rowWithNan = MapUtil.map("item", 2L, "weights", Arrays.asList(2.0, 3.0, Double.NaN));

    final List<Map<String, Object>> rows = new ArrayList<>();
    rows.add(rowWithoutNan);
    rows.add(rowWithNan);

    final WeightedInput[] retained = WeightedInput.prepareDenseWeights(rows, 2L, null);

    assertEquals(2, retained.length);
}

@Test
public void degreeCutoffWithNumericSkipValue() {
    // With 5.0 as the skip value and a cutoff of 2, the row containing 5.0
    // has only two countable weights and is expected to be filtered out.
    final Map<String, Object> rowWithoutSkipValue = MapUtil.map("item", 1L, "weights", Arrays.asList(2.0, 3.0, 4.0));
    final Map<String, Object> rowWithSkipValue = MapUtil.map("item", 2L, "weights", Arrays.asList(2.0, 3.0, 5.0));

    final List<Map<String, Object>> rows = new ArrayList<>();
    rows.add(rowWithoutSkipValue);
    rows.add(rowWithSkipValue);

    final WeightedInput[] retained = WeightedInput.prepareDenseWeights(rows, 2L, 5.0);

    assertEquals(1, retained.length);
}

@Test
public void pearsonNoCompression() {
double[] weights1 = new double[]{1, 2, 3, 4, 4, 4, 4, 5, 6};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ public static double pearsonSkip(double[] vector1, double[] vector2, int len, do
return Double.isNaN(result) ? 0 : result;
}

private static boolean shouldSkip(double weight, double skipValue, boolean skipNan) {
/**
 * Decides whether a weight should be ignored.
 *
 * <p>A weight is skipped when it compares equal ({@code ==}) to {@code skipValue},
 * or when {@code skipNan} is set and the weight is NaN. NaN never compares equal
 * to anything, which is why the NaN case needs its own flag.
 *
 * @param weight    the weight under test
 * @param skipValue the value that marks a weight as skippable
 * @param skipNan   whether NaN weights are skipped as well
 * @return {@code true} if the weight should be skipped
 */
public static boolean shouldSkip(double weight, double skipValue, boolean skipNan) {
    if (skipNan && Double.isNaN(weight)) {
        return true;
    }
    return weight == skipValue;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,7 @@ public void cosineSkipStreamTest() {
assertTrue(results.hasNext());
assert01Skip(results.next());
assert02Skip(results.next());
assert03Skip(results.next());
assert12Skip(results.next());
assert13Skip(results.next());
assert23Skip(results.next());
assertFalse(results.hasNext());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,7 @@ public void eucideanSkipStreamTest() {
assertTrue(results.hasNext());
assert01Skip(results.next());
assert02Skip(results.next());
assert03Skip(results.next());
assert12Skip(results.next());
assert13Skip(results.next());
assert23Skip(results.next());
assertFalse(results.hasNext());
}

Expand Down

0 comments on commit 5229765

Please sign in to comment.