From f4b7b8f7eae016c3187c0ed6bc9dcb0abf61eae4 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 18 Nov 2019 15:52:45 -0500 Subject: [PATCH 1/2] [ML][Inference] Fixing pre-processor value handling and size estimate --- .../ml/inference/preprocessing/FrequencyEncoding.java | 5 +++-- .../ml/inference/preprocessing/OneHotEncoding.java | 5 +++-- .../ml/inference/preprocessing/TargetMeanEncoding.java | 5 +++-- .../preprocessing/FrequencyEncodingTests.java | 10 +++++----- .../inference/preprocessing/OneHotEncodingTests.java | 6 +++--- .../ml/inference/preprocessing/PreProcessingTests.java | 4 ++-- .../preprocessing/TargetMeanEncodingTests.java | 10 +++++----- 7 files changed, 24 insertions(+), 21 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index cea99d3edc8f6..444c971edc16f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java @@ -103,7 +103,7 @@ public String getName() { @Override public void process(Map fields) { - String value = (String)fields.get(field); + String value = fields.get(field).toString(); if (value == null) { return; } @@ -152,7 +152,8 @@ public long ramBytesUsed() { long size = SHALLOW_SIZE; size += RamUsageEstimator.sizeOf(field); size += RamUsageEstimator.sizeOf(featureName); - size += RamUsageEstimator.sizeOfMap(frequencyMap); + // defSize:0 indicates that there is not a defined size. Finding the shallowSize of Double gives the best estimate + size += RamUsageEstimator.sizeOfMap(frequencyMap, 0); return size; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index 9784ed8cbe7aa..a22ca7ed20a3b 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -86,7 +86,7 @@ public String getName() { @Override public void process(Map fields) { - String value = (String)fields.get(field); + String value = fields.get(field).toString(); if (value == null) { return; } @@ -134,7 +134,8 @@ public int hashCode() { public long ramBytesUsed() { long size = SHALLOW_SIZE; size += RamUsageEstimator.sizeOf(field); - size += RamUsageEstimator.sizeOfMap(hotMap); + // defSize:0 does not do much in this case as sizeOf(String) is a known quantity + size += RamUsageEstimator.sizeOfMap(hotMap, 0); return size; } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index 914b43f98e967..73a86967e6e98 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -114,7 +114,7 @@ public String getName() { @Override public void process(Map fields) { - String value = (String)fields.get(field); + String value = fields.get(field).toString(); if (value == null) { return; } @@ -166,7 +166,8 @@ public long ramBytesUsed() { long size = SHALLOW_SIZE; size += RamUsageEstimator.sizeOf(field); size += RamUsageEstimator.sizeOf(featureName); - size += RamUsageEstimator.sizeOfMap(meanMap); + // defSize:0 indicates that there is not a defined size. Finding the shallowSize of Double gives the best estimate + size += RamUsageEstimator.sizeOfMap(meanMap, 0); return size; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java index 72047178e9f54..4c0497fa409f9 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncodingTests.java @@ -15,7 +15,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; @@ -48,13 +47,14 @@ protected Writeable.Reader instanceReader() { public void testProcessWithFieldPresent() { String field = "categorical"; - List values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote"); - Map valueMap = values.stream().collect(Collectors.toMap(Function.identity(), + List values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5); + Map valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> randomDoubleBetween(0.0, 1.0, false))); String encodedFeatureName = "encoded"; FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap); - String fieldValue = randomFrom(values); - Map> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue))); + Object fieldValue = randomFrom(values); + Map> matchers = Collections.singletonMap(encodedFeatureName, + equalTo(valueMap.get(fieldValue.toString()))); Map fieldValues = randomFieldValues(field, fieldValue); testProcess(encoding, fieldValues, matchers); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java index f0627719ec47c..8b35b77b5a69c 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncodingTests.java @@ -47,10 +47,10 @@ protected Writeable.Reader instanceReader() { public void testProcessWithFieldPresent() { String field = "categorical"; - List values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote"); - Map valueMap = values.stream().collect(Collectors.toMap(Function.identity(), v -> "Column_" + v)); + List values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0); + Map valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString())); OneHotEncoding encoding = new OneHotEncoding(field, valueMap); - String fieldValue = randomFrom(values); + Object fieldValue = randomFrom(values); Map fieldValues = randomFieldValues(field, fieldValue); Map> matchers = values.stream().map(v -> "Column_" + v) diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java index 4301b09c5ece7..c4e8b879bcdd2 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessingTests.java @@ -58,9 +58,9 @@ Map randomFieldValues() { return fieldValues; } - Map randomFieldValues(String categoricalField, String catigoricalValue) { + Map randomFieldValues(String categoricalField, Object categoricalValue) { Map fieldValues = randomFieldValues(); - fieldValues.put(categoricalField, catigoricalValue); + fieldValues.put(categoricalField, categoricalValue); return fieldValues; } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java index d86d9e09f0238..e2aaf1e1256c6 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncodingTests.java @@ -15,7 +15,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; import java.util.stream.Collectors; import static org.hamcrest.Matchers.equalTo; @@ -51,14 +50,15 @@ protected Writeable.Reader instanceReader() { public void testProcessWithFieldPresent() { String field = "categorical"; - List values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote"); - Map valueMap = values.stream().collect(Collectors.toMap(Function.identity(), + List values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0); + Map valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> randomDoubleBetween(0.0, 1.0, false))); String encodedFeatureName = "encoded"; Double defaultvalue = randomDouble(); TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue); - String fieldValue = randomFrom(values); - Map> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue))); + Object fieldValue = randomFrom(values); + Map> matchers = Collections.singletonMap(encodedFeatureName, + equalTo(valueMap.get(fieldValue.toString()))); Map fieldValues = randomFieldValues(field, fieldValue); testProcess(encoding, fieldValues, matchers); From 95bce2484102c23e722f421b3974d07097162682 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 19 Nov 2019 13:19:11 -0500 Subject: [PATCH 2/2] fixing npe --- .../core/ml/inference/preprocessing/FrequencyEncoding.java | 4 ++-- .../xpack/core/ml/inference/preprocessing/OneHotEncoding.java | 4 ++-- .../core/ml/inference/preprocessing/TargetMeanEncoding.java | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java index 444c971edc16f..ed693460edcc7 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java @@ -103,11 +103,11 @@ public String getName() { @Override public void process(Map fields) { - String value = fields.get(field).toString(); + Object value = fields.get(field); if (value == null) { return; } - fields.put(featureName, frequencyMap.getOrDefault(value, 0.0)); + fields.put(featureName, frequencyMap.getOrDefault(value.toString(), 0.0)); } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java index a22ca7ed20a3b..a4924a277c0ad 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java @@ -86,12 +86,12 @@ public String getName() { @Override public void process(Map fields) { - String value = fields.get(field).toString(); + Object value = fields.get(field); if (value == null) { return; } hotMap.forEach((val, col) -> { - int encoding = value.equals(val) ? 1 : 0; + int encoding = value.toString().equals(val) ? 1 : 0; fields.put(col, encoding); }); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java index 73a86967e6e98..8276fc2c8fefb 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java @@ -114,11 +114,11 @@ public String getName() { @Override public void process(Map fields) { - String value = fields.get(field).toString(); + Object value = fields.get(field); if (value == null) { return; } - fields.put(featureName, meanMap.getOrDefault(value, defaultValue)); + fields.put(featureName, meanMap.getOrDefault(value.toString(), defaultValue)); } @Override