Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Data Frame Analytics: Fix feature importance #61761

Merged
merged 11 commits into from
Apr 4, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ interface OutlierAnalysis {
// Shape of the settings for a regression analysis.
// Only `dependent_variable` is required; every other property may be omitted.
interface Regression {
dependent_variable: string;
training_percent?: number;
// Optional numeric setting. Presumably this is the value read by
// getNumTopFeatureImportanceValues, which treats `undefined` as 0 — confirm against callers.
num_top_feature_importance_values?: number;
prediction_field_name?: string;
}
export interface RegressionAnalysis {
Expand All @@ -44,6 +45,7 @@ interface Classification {
dependent_variable: string;
training_percent?: number;
num_top_classes?: string;
num_top_feature_importance_values?: number;
prediction_field_name?: string;
}
export interface ClassificationAnalysis {
Expand Down Expand Up @@ -186,6 +188,23 @@ export const getPredictionFieldName = (analysis: AnalysisConfig) => {
return predictionFieldName;
};

/**
 * Reads `num_top_feature_importance_values` from a regression or
 * classification analysis configuration.
 *
 * @param analysis - The analysis section of a job configuration.
 * @returns The configured value when it is defined on the regression or
 * classification settings; otherwise 0 (including for any other analysis type).
 */
export const getNumTopFeatureImportanceValues = (analysis: AnalysisConfig) => {
  if (isRegressionAnalysis(analysis)) {
    const regressionValue = analysis.regression.num_top_feature_importance_values;
    if (regressionValue !== undefined) {
      return regressionValue;
    }
  }

  // Only reached when the regression branch did not yield a defined value.
  if (isClassificationAnalysis(analysis)) {
    const classificationValue = analysis.classification.num_top_feature_importance_values;
    if (classificationValue !== undefined) {
      return classificationValue;
    }
  }

  // Nothing configured: fall back to 0.
  return 0;
};

export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import { getNestedProperty } from '../../util/object_utils';
import {
DataFrameAnalyticsConfig,
getNumTopFeatureImportanceValues,
getPredictedFieldName,
getDependentVar,
getPredictionFieldName,
} from './analytics';
import { Field } from '../../../../common/types/fields';
import { ES_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { ES_FIELD_TYPES, KBN_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { newJobCapsService } from '../../services/new_job_capabilities_service';

export type EsId = string;
Expand Down Expand Up @@ -253,6 +254,7 @@ export const getDefaultFieldsFromJobCaps = (
const dependentVariable = getDependentVar(jobConfig.analysis);
const type = newJobCapsService.getFieldById(dependentVariable)?.type;
const predictionFieldName = getPredictionFieldName(jobConfig.analysis);
const numTopFeatureImportanceValues = getNumTopFeatureImportanceValues(jobConfig.analysis);
// default is 'ml'
const resultsField = jobConfig.dest.results_field;

Expand All @@ -261,6 +263,18 @@ export const getDefaultFieldsFromJobCaps = (
predictionFieldName ? predictionFieldName : defaultPredictionField
}`;

const featureImportanceFields = [];

if (numTopFeatureImportanceValues > 0) {
featureImportanceFields.push(
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
...fields.map(d => ({
id: `${resultsField}.feature_importance.${d.id}`,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This highlights that we need to use EuiDataGrid for classification and regression results, as I can't actually tell which columns are which!

image

name: `${resultsField}.feature_importance.${d.name}`,
type: KBN_FIELD_TYPES.NUMBER,
}))
);
}

const allFields: any = [
{
id: `${resultsField}.is_training`,
Expand All @@ -269,11 +283,14 @@ export const getDefaultFieldsFromJobCaps = (
},
{ id: predictedField, name: predictedField, type },
...fields,
...featureImportanceFields,
].sort(({ name: a }, { name: b }) => sortRegressionResultsFields(a, b, jobConfig));

let selectedFields = allFields
.slice(0, DEFAULT_REGRESSION_COLUMNS * 2)
.filter((field: any) => field.name === predictedField || !field.name.includes('.keyword'));
let selectedFields = allFields.filter(
(field: any) =>
field.name === predictedField ||
(!field.name.includes('.keyword') && !field.name.includes('.feature_importance.'))
);

if (selectedFields.length > DEFAULT_REGRESSION_COLUMNS) {
selectedFields = selectedFields.slice(0, DEFAULT_REGRESSION_COLUMNS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ describe('Analytics job clone action', () => {
classification: {
dependent_variable: 'y',
num_top_classes: 2,
num_top_feature_importance_values: 4,
prediction_field_name: 'y_prediction',
training_percent: 2,
randomize_seed: 6233212276062807000,
Expand Down Expand Up @@ -90,6 +91,7 @@ describe('Analytics job clone action', () => {
prediction_field_name: 'stab_prediction',
training_percent: 20,
randomize_seed: -2228827740028660200,
num_top_feature_importance_values: 4,
},
},
analyzed_fields: {
Expand Down Expand Up @@ -120,6 +122,7 @@ describe('Analytics job clone action', () => {
classification: {
dependent_variable: 'y',
num_top_classes: 2,
num_top_feature_importance_values: 4,
prediction_field_name: 'y_prediction',
training_percent: 2,
randomize_seed: 6233212276062807000,
Expand Down Expand Up @@ -188,6 +191,7 @@ describe('Analytics job clone action', () => {
prediction_field_name: 'stab_prediction',
training_percent: 20,
randomize_seed: -2228827740028660200,
num_top_feature_importance_values: 4,
},
},
analyzed_fields: {
Expand Down Expand Up @@ -218,6 +222,7 @@ describe('Analytics job clone action', () => {
dependent_variable: 'y',
training_percent: 71,
max_trees: 1500,
num_top_feature_importance_values: 4,
},
},
model_memory_limit: '400mb',
Expand All @@ -243,6 +248,7 @@ describe('Analytics job clone action', () => {
dependent_variable: 'y',
training_percent: 71,
maximum_number_trees: 1500,
num_top_feature_importance_values: 4,
},
},
model_memory_limit: '400mb',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
},
num_top_feature_importance_values: {
optional: true,
defaultValue: 0,
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
formKey: 'numTopFeatureImportanceValues',
},
class_assignment_objective: {
optional: true,
Expand Down Expand Up @@ -164,6 +166,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
},
num_top_feature_importance_values: {
optional: true,
defaultValue: 0,
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
formKey: 'numTopFeatureImportanceValues',
},
randomize_seed: {
optional: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ export const CreateAnalyticsForm: FC<CreateAnalyticsFormProps> = ({ actions, sta
maxDistinctValuesError,
modelMemoryLimit,
modelMemoryLimitValidationResult,
numTopFeatureImportanceValues,
previousJobType,
previousSourceIndex,
sourceIndex,
Expand Down Expand Up @@ -643,6 +644,31 @@ export const CreateAnalyticsForm: FC<CreateAnalyticsFormProps> = ({ actions, sta
data-test-subj="mlAnalyticsCreateJobFlyoutTrainingPercentSlider"
/>
</EuiFormRow>
{/* num_top_feature_importance_values */}
<EuiFormRow
label={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesLabel',
{
defaultMessage: 'Max. number of feature importance values per document',
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
}
)}
>
<EuiFieldText
disabled={false}
value={numTopFeatureImportanceValues}
onChange={e =>
setFormState({ numTopFeatureImportanceValues: parseInt(e.target.value, 10) })
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
}
aria-label={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesInputAriaLabel',
{
defaultMessage: 'Max. number of feature importance values per document.',
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
}
)}
isInvalid={!destinationIndexNameEmpty && !destinationIndexNameValid}
data-test-subj="mlAnalyticsCreateJobFlyoutnumTopFeatureImportanceValuesInput"
/>
</EuiFormRow>
</Fragment>
)}
<EuiFormRow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ export interface State {
modelMemoryLimit: string | undefined;
modelMemoryLimitUnitValid: boolean;
modelMemoryLimitValidationResult: any;
numTopFeatureImportanceValues: number | undefined;
previousJobType: null | AnalyticsJobType;
previousSourceIndex: EsIndexName | undefined;
sourceIndex: EsIndexName;
Expand Down Expand Up @@ -123,6 +124,7 @@ export const getInitialState = (): State => ({
modelMemoryLimit: undefined,
modelMemoryLimitUnitValid: true,
modelMemoryLimitValidationResult: null,
numTopFeatureImportanceValues: 0,
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
previousJobType: null,
previousSourceIndex: undefined,
sourceIndex: '',
Expand Down Expand Up @@ -182,6 +184,7 @@ export const getJobConfigFromFormState = (
jobConfig.analysis = {
[formState.jobType]: {
dependent_variable: formState.dependentVariable,
num_top_feature_importance_values: formState.numTopFeatureImportanceValues,
training_percent: formState.trainingPercent,
},
};
Expand Down Expand Up @@ -216,6 +219,7 @@ export function getCloneFormStateFromJobConfig(
const analysisConfig = analyticsJobConfig.analysis[jobType];

resultState.dependentVariable = analysisConfig.dependent_variable;
resultState.numTopFeatureImportanceValues = analysisConfig.num_top_feature_importance_values;
resultState.trainingPercent = analysisConfig.training_percent;
}

Expand Down