Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML][Inference] Fixing pre-processor value handling and size estimate #49270

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,11 @@ public String getName() {

@Override
public void process(Map<String, Object> fields) {
String value = (String)fields.get(field);
Object value = fields.get(field);
if (value == null) {
return;
}
fields.put(featureName, frequencyMap.getOrDefault(value, 0.0));
fields.put(featureName, frequencyMap.getOrDefault(value.toString(), 0.0));
}

@Override
Expand Down Expand Up @@ -152,7 +152,8 @@ public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(field);
size += RamUsageEstimator.sizeOf(featureName);
size += RamUsageEstimator.sizeOfMap(frequencyMap);
// defSize:0 indicates that there is not a defined size. Finding the shallowSize of Double gives the best estimate
size += RamUsageEstimator.sizeOfMap(frequencyMap, 0);
return size;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,12 @@ public String getName() {

@Override
public void process(Map<String, Object> fields) {
String value = (String)fields.get(field);
Object value = fields.get(field);
if (value == null) {
return;
}
hotMap.forEach((val, col) -> {
int encoding = value.equals(val) ? 1 : 0;
int encoding = value.toString().equals(val) ? 1 : 0;
fields.put(col, encoding);
});
}
Expand Down Expand Up @@ -134,7 +134,8 @@ public int hashCode() {
public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(field);
size += RamUsageEstimator.sizeOfMap(hotMap);
// defSize:0 does not do much in this case as sizeOf(String) is a known quantity
size += RamUsageEstimator.sizeOfMap(hotMap, 0);
return size;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,11 @@ public String getName() {

@Override
public void process(Map<String, Object> fields) {
String value = (String)fields.get(field);
Object value = fields.get(field);
if (value == null) {
return;
}
fields.put(featureName, meanMap.getOrDefault(value, defaultValue));
fields.put(featureName, meanMap.getOrDefault(value.toString(), defaultValue));
}

@Override
Expand Down Expand Up @@ -166,7 +166,8 @@ public long ramBytesUsed() {
long size = SHALLOW_SIZE;
size += RamUsageEstimator.sizeOf(field);
size += RamUsageEstimator.sizeOf(featureName);
size += RamUsageEstimator.sizeOfMap(meanMap);
// defSize:0 indicates that there is not a defined size. Finding the shallowSize of Double gives the best estimate
size += RamUsageEstimator.sizeOfMap(meanMap, 0);
return size;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -48,13 +47,14 @@ protected Writeable.Reader<FrequencyEncoding> instanceReader() {

public void testProcessWithFieldPresent() {
String field = "categorical";
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Function.identity(),
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.5);
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
FrequencyEncoding encoding = new FrequencyEncoding(field, encodedFeatureName, valueMap);
String fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue)));
Object fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName,
equalTo(valueMap.get(fieldValue.toString())));
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
testProcess(encoding, fieldValues, matchers);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ protected Writeable.Reader<OneHotEncoding> instanceReader() {

public void testProcessWithFieldPresent() {
String field = "categorical";
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Function.identity(), v -> "Column_" + v));
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
Map<String, String> valueMap = values.stream().collect(Collectors.toMap(Object::toString, v -> "Column_" + v.toString()));
OneHotEncoding encoding = new OneHotEncoding(field, valueMap);
String fieldValue = randomFrom(values);
Object fieldValue = randomFrom(values);
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);

Map<String, Matcher<? super Object>> matchers = values.stream().map(v -> "Column_" + v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ Map<String, Object> randomFieldValues() {
return fieldValues;
}

Map<String, Object> randomFieldValues(String categoricalField, String catigoricalValue) {
Map<String, Object> randomFieldValues(String categoricalField, Object categoricalValue) {
Map<String, Object> fieldValues = randomFieldValues();
fieldValues.put(categoricalField, catigoricalValue);
fieldValues.put(categoricalField, categoricalValue);
return fieldValues;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.hamcrest.Matchers.equalTo;
Expand Down Expand Up @@ -51,14 +50,15 @@ protected Writeable.Reader<TargetMeanEncoding> instanceReader() {

public void testProcessWithFieldPresent() {
String field = "categorical";
List<String> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote");
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Function.identity(),
List<Object> values = Arrays.asList("foo", "bar", "foobar", "baz", "farequote", 1.0);
Map<String, Double> valueMap = values.stream().collect(Collectors.toMap(Object::toString,
v -> randomDoubleBetween(0.0, 1.0, false)));
String encodedFeatureName = "encoded";
Double defaultvalue = randomDouble();
TargetMeanEncoding encoding = new TargetMeanEncoding(field, encodedFeatureName, valueMap, defaultvalue);
String fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName, equalTo(valueMap.get(fieldValue)));
Object fieldValue = randomFrom(values);
Map<String, Matcher<? super Object>> matchers = Collections.singletonMap(encodedFeatureName,
equalTo(valueMap.get(fieldValue.toString())));
Map<String, Object> fieldValues = randomFieldValues(field, fieldValue);
testProcess(encoding, fieldValues, matchers);

Expand Down