Skip to content

Commit

Permalink
[ML][Inference] Fixing pre-processor value handling and size estimate (
Browse files Browse the repository at this point in the history
…#49270) (#49489)

* [ML][Inference] Fixing pre-processor value handling and size estimate

* Fixing NPE (NullPointerException)
  • Loading branch information
benwtrent authored Nov 22, 2019
1 parent 35cc0e0 commit 276b6c6
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ public String getName() {

@Override
public void process(Map<String, Object> fields) {
String value = (String)fields.get(field);
Object value = fields.get(field);
if (value == null) {
return;
}
fields.put(featureName, frequencyMap.getOrDefault(value, 0.0));
fields.put(featureName, frequencyMap.getOrDefault(value.toString(), 0.0));
}

@Override
Expand Down Expand Up @@ -152,7 +152,8 @@ public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(field);
size += RamUsageEstimator.sizeOf(featureName);
size += RamUsageEstimator.sizeOfMap(frequencyMap);
// A defSize of 0 indicates that no default size is defined, so the shallow size of each Double is used, which gives the best estimate
size += RamUsageEstimator.sizeOfMap(frequencyMap, 0);
return size;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ public String getName() {

@Override
public void process(Map<String, Object> fields) {
String value = (String)fields.get(field);
Object value = fields.get(field);
if (value == null) {
return;
}
hotMap.forEach((val, col) -> {
int encoding = value.equals(val) ? 1 : 0;
int encoding = value.toString().equals(val) ? 1 : 0;
fields.put(col, encoding);
});
}
Expand Down Expand Up @@ -134,7 +134,8 @@ public int hashCode() {
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(field);
size += RamUsageEstimator.sizeOfMap(hotMap);
// A defSize of 0 has little effect in this case, as sizeOf(String) is a known quantity
size += RamUsageEstimator.sizeOfMap(hotMap, 0);
return size;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ public String getName() {

@Override
public void process(Map<String, Object> fields) {
String value = (String)fields.get(field);
Object value = fields.get(field);
if (value == null) {
return;
}
fields.put(featureName, meanMap.getOrDefault(value, defaultValue));
fields.put(featureName, meanMap.getOrDefault(value.toString(), defaultValue));
}

@Override
Expand Down Expand Up @@ -166,7 +166,8 @@ public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(field);
size += RamUsageEstimator.sizeOf(featureName);
size += RamUsageEstimator.sizeOfMap(meanMap);
// A defSize of 0 indicates that no default size is defined, so the shallow size of each Double is used, which gives the best estimate
size += RamUsageEstimator.sizeOfMap(meanMap, 0);
return size;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -48,13 +47,14 @@ protected Writeable.Reader<FrequencyEncoding> instanceReader() {

public void testProcessWithFieldPresent() {
String field = "categorical";
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Function.identity(),
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
String fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue)));
Object fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName,
equalTo(valueMap.get(fieldValue.toString())));
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
testProcess(encoding, fieldValues, matchers);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ protected Writeable.Reader<OneHotEncoding> instanceReader() {

public void testProcessWithFieldPresent() {
String field = "categorical";
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Function.identity(), v -> "Column_" + v));
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
String fieldValue = randomFrom(values);
Object fieldValue = randomFrom(values);
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);

Map<String, Matcher<? super Object>> matchers = values.stream().map(v -> "Column_" + v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ Map<String, Object> randomFieldValues() {
return fieldValues;
}

Map<String, Object> randomFieldValues(String categoricalField, String catigoricalValue) {
Map<String, Object> randomFieldValues(String categoricalField, Object categoricalValue) {
Map<String, Object> fieldValues = randomFieldValues();
fieldValues.put(categoricalField, catigoricalValue);
fieldValues.put(categoricalField, categoricalValue);
return fieldValues;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -51,14 +50,15 @@ protected Writeable.Reader<TargetMeanEncoding> instanceReader() {

public void testProcessWithFieldPresent() {
String field = "categorical";
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Function.identity(),
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
Double defaultvalue = randomDouble();
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
String fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue)));
Object fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName,
equalTo(valueMap.get(fieldValue.toString())));
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
testProcess(encoding, fieldValues, matchers);

Expand Down

0 comments on commit 276b6c6

Please sign in to comment.