Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Data Frame Analytics: Fix feature importance #61761

Merged
merged 11 commits into from
Apr 4, 2020
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ interface OutlierAnalysis {
interface Regression {
dependent_variable: string;
training_percent?: number;
num_top_feature_importance_values?: number;
prediction_field_name?: string;
}
export interface RegressionAnalysis {
Expand All @@ -44,6 +45,7 @@ interface Classification {
dependent_variable: string;
training_percent?: number;
num_top_classes?: string;
num_top_feature_importance_values?: number;
prediction_field_name?: string;
}
export interface ClassificationAnalysis {
Expand All @@ -65,6 +67,8 @@ export const SEARCH_SIZE = 1000;
export const TRAINING_PERCENT_MIN = 1;
export const TRAINING_PERCENT_MAX = 100;

export const NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN = 0;

export const defaultSearchQuery = {
match_all: {},
};
Expand Down Expand Up @@ -152,7 +156,7 @@ type AnalysisConfig =
| ClassificationAnalysis
| GenericAnalysis;

export const getAnalysisType = (analysis: AnalysisConfig) => {
export const getAnalysisType = (analysis: AnalysisConfig): string => {
const keys = Object.keys(analysis);

if (keys.length === 1) {
Expand All @@ -162,7 +166,11 @@ export const getAnalysisType = (analysis: AnalysisConfig) => {
return 'unknown';
};

export const getDependentVar = (analysis: AnalysisConfig) => {
export const getDependentVar = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['dependent_variable']
| ClassificationAnalysis['classification']['dependent_variable'] => {
let depVar = '';

if (isRegressionAnalysis(analysis)) {
Expand All @@ -175,7 +183,11 @@ export const getDependentVar = (analysis: AnalysisConfig) => {
return depVar;
};

export const getTrainingPercent = (analysis: AnalysisConfig) => {
export const getTrainingPercent = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['training_percent']
| ClassificationAnalysis['classification']['training_percent'] => {
let trainingPercent;

if (isRegressionAnalysis(analysis)) {
Expand All @@ -188,7 +200,11 @@ export const getTrainingPercent = (analysis: AnalysisConfig) => {
return trainingPercent;
};

export const getPredictionFieldName = (analysis: AnalysisConfig) => {
export const getPredictionFieldName = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['prediction_field_name']
| ClassificationAnalysis['classification']['prediction_field_name'] => {
// If undefined will be defaulted to dependent_variable when config is created
let predictionFieldName;
if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) {
Expand All @@ -202,6 +218,26 @@ export const getPredictionFieldName = (analysis: AnalysisConfig) => {
return predictionFieldName;
};

export const getNumTopFeatureImportanceValues = (
analysis: AnalysisConfig
):
| RegressionAnalysis['regression']['num_top_feature_importance_values']
| ClassificationAnalysis['classification']['num_top_feature_importance_values'] => {
let numTopFeatureImportanceValues;
if (
isRegressionAnalysis(analysis) &&
analysis.regression.num_top_feature_importance_values !== undefined
) {
numTopFeatureImportanceValues = analysis.regression.num_top_feature_importance_values;
} else if (
isClassificationAnalysis(analysis) &&
analysis.classification.num_top_feature_importance_values !== undefined
) {
numTopFeatureImportanceValues = analysis.classification.num_top_feature_importance_values;
}
return numTopFeatureImportanceValues;
};

export const getPredictedFieldName = (
resultsField: string,
analysis: AnalysisConfig,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import { getNestedProperty } from '../../util/object_utils';
import {
DataFrameAnalyticsConfig,
getNumTopFeatureImportanceValues,
getPredictedFieldName,
getDependentVar,
getPredictionFieldName,
} from './analytics';
import { Field } from '../../../../common/types/fields';
import { ES_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { ES_FIELD_TYPES, KBN_FIELD_TYPES } from '../../../../../../../src/plugins/data/public';
import { newJobCapsService } from '../../services/new_job_capabilities_service';

export type EsId = string;
Expand Down Expand Up @@ -254,14 +255,28 @@ export const getDefaultFieldsFromJobCaps = (
const dependentVariable = getDependentVar(jobConfig.analysis);
const type = newJobCapsService.getFieldById(dependentVariable)?.type;
const predictionFieldName = getPredictionFieldName(jobConfig.analysis);
const numTopFeatureImportanceValues = getNumTopFeatureImportanceValues(jobConfig.analysis);
// default is 'ml'
const resultsField = jobConfig.dest.results_field;

const defaultPredictionField = `${dependentVariable}_prediction`;
const predictedField = `${resultsField}.${
predictionFieldName ? predictionFieldName : defaultPredictionField
}`;
// Only need to add these first two fields if we didn't use dest index pattern to get the fields

const featureImportanceFields = [];

if ((numTopFeatureImportanceValues ?? 0) > 0) {
featureImportanceFields.push(
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
...fields.map(d => ({
id: `${resultsField}.feature_importance.${d.id}`,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This highlights that we need to use EuiDataGrid for classification and regression results, as I can't actually tell which columns are which!

image

name: `${resultsField}.feature_importance.${d.name}`,
type: KBN_FIELD_TYPES.NUMBER,
}))
);
}

// Only need to add these fields if we didn't use dest index pattern to get the fields
const allFields: any =
needsDestIndexFields === true
? [
Expand All @@ -271,16 +286,20 @@ export const getDefaultFieldsFromJobCaps = (
type: ES_FIELD_TYPES.BOOLEAN,
},
{ id: predictedField, name: predictedField, type },
...featureImportanceFields,
]
: [];

allFields.push(...fields);
// @ts-ignore
allFields.sort(({ name: a }, { name: b }) => sortRegressionResultsFields(a, b, jobConfig));

let selectedFields = allFields
.slice(0, DEFAULT_REGRESSION_COLUMNS * 2)
.filter((field: any) => field.name === predictedField || !field.name.includes('.keyword'));
allFields.sort(({ name: a }: { name: string }, { name: b }: { name: string }) =>
sortRegressionResultsFields(a, b, jobConfig)
);

let selectedFields = allFields.filter(
(field: any) =>
field.name === predictedField ||
(!field.name.includes('.keyword') && !field.name.includes('.feature_importance.'))
);

if (selectedFields.length > DEFAULT_REGRESSION_COLUMNS) {
selectedFields = selectedFields.slice(0, DEFAULT_REGRESSION_COLUMNS);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ describe('Analytics job clone action', () => {
classification: {
dependent_variable: 'y',
num_top_classes: 2,
num_top_feature_importance_values: 4,
prediction_field_name: 'y_prediction',
training_percent: 2,
randomize_seed: 6233212276062807000,
Expand Down Expand Up @@ -90,6 +91,7 @@ describe('Analytics job clone action', () => {
prediction_field_name: 'stab_prediction',
training_percent: 20,
randomize_seed: -2228827740028660200,
num_top_feature_importance_values: 4,
},
},
analyzed_fields: {
Expand Down Expand Up @@ -120,6 +122,7 @@ describe('Analytics job clone action', () => {
classification: {
dependent_variable: 'y',
num_top_classes: 2,
num_top_feature_importance_values: 4,
prediction_field_name: 'y_prediction',
training_percent: 2,
randomize_seed: 6233212276062807000,
Expand Down Expand Up @@ -188,6 +191,7 @@ describe('Analytics job clone action', () => {
prediction_field_name: 'stab_prediction',
training_percent: 20,
randomize_seed: -2228827740028660200,
num_top_feature_importance_values: 4,
},
},
analyzed_fields: {
Expand Down Expand Up @@ -218,6 +222,7 @@ describe('Analytics job clone action', () => {
dependent_variable: 'y',
training_percent: 71,
max_trees: 1500,
num_top_feature_importance_values: 4,
},
},
model_memory_limit: '400mb',
Expand All @@ -243,6 +248,7 @@ describe('Analytics job clone action', () => {
dependent_variable: 'y',
training_percent: 71,
maximum_number_trees: 1500,
num_top_feature_importance_values: 4,
},
},
model_memory_limit: '400mb',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ import { i18n } from '@kbn/i18n';
import { DeepReadonly } from '../../../../../../../common/types/common';
import { DataFrameAnalyticsConfig, isOutlierAnalysis } from '../../../../common';
import { isClassificationAnalysis, isRegressionAnalysis } from '../../../../common/analytics';
import { CreateAnalyticsFormProps } from '../../hooks/use_create_analytics_form';
import {
CreateAnalyticsFormProps,
DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
} from '../../hooks/use_create_analytics_form';
import { State } from '../../hooks/use_create_analytics_form/state';
import { DataFrameAnalyticsListRow } from './common';

Expand Down Expand Up @@ -97,6 +100,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
},
num_top_feature_importance_values: {
optional: true,
defaultValue: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
formKey: 'numTopFeatureImportanceValues',
},
class_assignment_objective: {
optional: true,
Expand Down Expand Up @@ -164,6 +169,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo
},
num_top_feature_importance_values: {
optional: true,
defaultValue: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES,
formKey: 'numTopFeatureImportanceValues',
},
randomize_seed: {
optional: true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
EuiComboBox,
EuiComboBoxOptionOption,
EuiForm,
EuiFieldNumber,
EuiFieldText,
EuiFormRow,
EuiLink,
Expand Down Expand Up @@ -41,6 +42,7 @@ import {
ANALYSIS_CONFIG_TYPE,
DfAnalyticsExplainResponse,
FieldSelectionItem,
NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN,
TRAINING_PERCENT_MIN,
TRAINING_PERCENT_MAX,
} from '../../../../common/analytics';
Expand Down Expand Up @@ -83,6 +85,8 @@ export const CreateAnalyticsForm: FC<CreateAnalyticsFormProps> = ({ actions, sta
maxDistinctValuesError,
modelMemoryLimit,
modelMemoryLimitValidationResult,
numTopFeatureImportanceValues,
numTopFeatureImportanceValuesValid,
previousJobType,
previousSourceIndex,
sourceIndex,
Expand Down Expand Up @@ -645,6 +649,54 @@ export const CreateAnalyticsForm: FC<CreateAnalyticsFormProps> = ({ actions, sta
data-test-subj="mlAnalyticsCreateJobFlyoutTrainingPercentSlider"
/>
</EuiFormRow>
{/* num_top_feature_importance_values */}
<EuiFormRow
label={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesLabel',
{
defaultMessage: 'Feature importance values',
}
)}
helpText={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesHelpText',
{
defaultMessage:
'Specify the maximum number of feature importance values per document to return.',
}
)}
isInvalid={numTopFeatureImportanceValuesValid === false}
error={[
...(numTopFeatureImportanceValuesValid === false
? [
<Fragment>
{i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesErrorText',
{
defaultMessage:
'Invalid maximum number of feature importance values.',
}
)}
</Fragment>,
]
: []),
]}
>
<EuiFieldNumber
peteharverson marked this conversation as resolved.
Show resolved Hide resolved
aria-label={i18n.translate(
'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesInputAriaLabel',
{
defaultMessage: 'Maximum number of feature importance values per document.',
}
)}
data-test-subj="mlAnalyticsCreateJobFlyoutnumTopFeatureImportanceValuesInput"
disabled={false}
isInvalid={numTopFeatureImportanceValuesValid === false}
min={NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN}
onChange={e => setFormState({ numTopFeatureImportanceValues: +e.target.value })}
step={1}
value={numTopFeatureImportanceValues}
/>
</EuiFormRow>
</Fragment>
)}
<EuiFormRow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
* you may not use this file except in compliance with the Elastic License.
*/

export { DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES } from './state';
export { useCreateAnalyticsForm, CreateAnalyticsFormProps } from './use_create_analytics_form';
Loading