diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts index d77f19c0df79d..511ebb7e1647a 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts @@ -33,6 +33,7 @@ interface OutlierAnalysis { interface Regression { dependent_variable: string; training_percent?: number; + num_top_feature_importance_values?: number; prediction_field_name?: string; } export interface RegressionAnalysis { @@ -44,6 +45,7 @@ interface Classification { dependent_variable: string; training_percent?: number; num_top_classes?: string; + num_top_feature_importance_values?: number; prediction_field_name?: string; } export interface ClassificationAnalysis { @@ -65,6 +67,8 @@ export const SEARCH_SIZE = 1000; export const TRAINING_PERCENT_MIN = 1; export const TRAINING_PERCENT_MAX = 100; +export const NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN = 0; + export const defaultSearchQuery = { match_all: {}, }; @@ -152,7 +156,7 @@ type AnalysisConfig = | ClassificationAnalysis | GenericAnalysis; -export const getAnalysisType = (analysis: AnalysisConfig) => { +export const getAnalysisType = (analysis: AnalysisConfig): string => { const keys = Object.keys(analysis); if (keys.length === 1) { @@ -162,7 +166,11 @@ export const getAnalysisType = (analysis: AnalysisConfig) => { return 'unknown'; }; -export const getDependentVar = (analysis: AnalysisConfig) => { +export const getDependentVar = ( + analysis: AnalysisConfig +): + | RegressionAnalysis['regression']['dependent_variable'] + | ClassificationAnalysis['classification']['dependent_variable'] => { let depVar = ''; if (isRegressionAnalysis(analysis)) { @@ -175,7 +183,11 @@ export const getDependentVar = (analysis: AnalysisConfig) => { return depVar; }; -export const getTrainingPercent = (analysis: AnalysisConfig) => { +export const getTrainingPercent = ( + analysis: AnalysisConfig +): + | RegressionAnalysis['regression']['training_percent'] + | ClassificationAnalysis['classification']['training_percent'] => { let trainingPercent; if (isRegressionAnalysis(analysis)) { @@ -188,7 +200,11 @@ export const getTrainingPercent = (analysis: AnalysisConfig) => { return trainingPercent; }; -export const getPredictionFieldName = (analysis: AnalysisConfig) => { +export const getPredictionFieldName = ( + analysis: AnalysisConfig +): + | RegressionAnalysis['regression']['prediction_field_name'] + | ClassificationAnalysis['classification']['prediction_field_name'] => { // If undefined will be defaulted to dependent_variable when config is created let predictionFieldName; if (isRegressionAnalysis(analysis) && analysis.regression.prediction_field_name !== undefined) { @@ -202,6 +218,26 @@ export const getPredictionFieldName = (analysis: AnalysisConfig) => { return predictionFieldName; }; +export const getNumTopFeatureImportanceValues = ( + analysis: AnalysisConfig +): + | RegressionAnalysis['regression']['num_top_feature_importance_values'] + | ClassificationAnalysis['classification']['num_top_feature_importance_values'] => { + let numTopFeatureImportanceValues; + if ( + isRegressionAnalysis(analysis) && + analysis.regression.num_top_feature_importance_values !== undefined + ) { + numTopFeatureImportanceValues = analysis.regression.num_top_feature_importance_values; + } else if ( + isClassificationAnalysis(analysis) && + analysis.classification.num_top_feature_importance_values !== undefined + ) { + numTopFeatureImportanceValues = analysis.classification.num_top_feature_importance_values; + } + return numTopFeatureImportanceValues; +}; + export const getPredictedFieldName = ( resultsField: string, analysis: AnalysisConfig, diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/common/fields.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/common/fields.ts index 59b42935a141d..92d8731959895 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/common/fields.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/common/fields.ts @@ -7,12 +7,13 @@ import { getNestedProperty } from '../../util/object_utils'; import { DataFrameAnalyticsConfig, + getNumTopFeatureImportanceValues, getPredictedFieldName, getDependentVar, getPredictionFieldName, } from './analytics'; import { Field } from '../../../../common/types/fields'; -import { ES_FIELD_TYPES } from '../../../../../../../src/plugins/data/public'; +import { ES_FIELD_TYPES, KBN_FIELD_TYPES } from '../../../../../../../src/plugins/data/public'; import { newJobCapsService } from '../../services/new_job_capabilities_service'; export type EsId = string; @@ -254,6 +255,7 @@ export const getDefaultFieldsFromJobCaps = ( const dependentVariable = getDependentVar(jobConfig.analysis); const type = newJobCapsService.getFieldById(dependentVariable)?.type; const predictionFieldName = getPredictionFieldName(jobConfig.analysis); + const numTopFeatureImportanceValues = getNumTopFeatureImportanceValues(jobConfig.analysis); // default is 'ml' const resultsField = jobConfig.dest.results_field; @@ -261,7 +263,20 @@ export const getDefaultFieldsFromJobCaps = ( const predictedField = `${resultsField}.${ predictionFieldName ? predictionFieldName : defaultPredictionField }`; - // Only need to add these first two fields if we didn't use dest index pattern to get the fields + + const featureImportanceFields = []; + + if ((numTopFeatureImportanceValues ?? 0) > 0) { + featureImportanceFields.push( + ...fields.map(d => ({ + id: `${resultsField}.feature_importance.${d.id}`, + name: `${resultsField}.feature_importance.${d.name}`, + type: KBN_FIELD_TYPES.NUMBER, + })) + ); + } + + // Only need to add these fields if we didn't use dest index pattern to get the fields const allFields: any = needsDestIndexFields === true ? [ @@ -271,16 +286,20 @@ export const getDefaultFieldsFromJobCaps = ( type: ES_FIELD_TYPES.BOOLEAN, }, { id: predictedField, name: predictedField, type }, + ...featureImportanceFields, ] : []; allFields.push(...fields); - // @ts-ignore - allFields.sort(({ name: a }, { name: b }) => sortRegressionResultsFields(a, b, jobConfig)); - - let selectedFields = allFields - .slice(0, DEFAULT_REGRESSION_COLUMNS * 2) - .filter((field: any) => field.name === predictedField || !field.name.includes('.keyword')); + allFields.sort(({ name: a }: { name: string }, { name: b }: { name: string }) => + sortRegressionResultsFields(a, b, jobConfig) + ); + + let selectedFields = allFields.filter( + (field: any) => + field.name === predictedField || + (!field.name.includes('.keyword') && !field.name.includes('.feature_importance.')) + ); if (selectedFields.length > DEFAULT_REGRESSION_COLUMNS) { selectedFields = selectedFields.slice(0, DEFAULT_REGRESSION_COLUMNS); diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_list/action_clone.test.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_list/action_clone.test.ts index 6225bca592be3..2463da054d140 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_list/action_clone.test.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_list/action_clone.test.ts @@ -25,6 +25,7 @@ describe('Analytics job clone action', () => { classification: { dependent_variable: 'y', num_top_classes: 2, + num_top_feature_importance_values: 4, prediction_field_name: 'y_prediction', training_percent: 2, randomize_seed: 6233212276062807000, @@ -90,6 +91,7 @@ describe('Analytics job clone action', () => { prediction_field_name: 'stab_prediction', training_percent: 20, randomize_seed: -2228827740028660200, + num_top_feature_importance_values: 4, }, }, analyzed_fields: { @@ -120,6 +122,7 @@ describe('Analytics job clone action', () => { classification: { dependent_variable: 'y', num_top_classes: 2, + num_top_feature_importance_values: 4, prediction_field_name: 'y_prediction', training_percent: 2, randomize_seed: 6233212276062807000, @@ -188,6 +191,7 @@ describe('Analytics job clone action', () => { prediction_field_name: 'stab_prediction', training_percent: 20, randomize_seed: -2228827740028660200, + num_top_feature_importance_values: 4, }, }, analyzed_fields: { @@ -218,6 +222,7 @@ describe('Analytics job clone action', () => { dependent_variable: 'y', training_percent: 71, max_trees: 1500, + num_top_feature_importance_values: 4, }, }, model_memory_limit: '400mb', @@ -243,6 +248,7 @@ describe('Analytics job clone action', () => { dependent_variable: 'y', training_percent: 71, maximum_number_trees: 1500, + num_top_feature_importance_values: 4, }, }, model_memory_limit: '400mb', diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_list/action_clone.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_list/action_clone.tsx index 3a0f98fc5acaa..eb1871c98764b 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_list/action_clone.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/analytics_list/action_clone.tsx @@ -11,7 +11,10 @@ import { i18n } from '@kbn/i18n'; import { DeepReadonly } from '../../../../../../../common/types/common'; import { DataFrameAnalyticsConfig, isOutlierAnalysis } from '../../../../common'; import { isClassificationAnalysis, isRegressionAnalysis } from '../../../../common/analytics'; -import { CreateAnalyticsFormProps } from '../../hooks/use_create_analytics_form'; +import { + CreateAnalyticsFormProps, + DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES, +} from '../../hooks/use_create_analytics_form'; import { State } from '../../hooks/use_create_analytics_form/state'; import { DataFrameAnalyticsListRow } from './common'; @@ -97,6 +100,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo }, num_top_feature_importance_values: { optional: true, + defaultValue: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES, + formKey: 'numTopFeatureImportanceValues', }, class_assignment_objective: { optional: true, @@ -164,6 +169,8 @@ const getAnalyticsJobMeta = (config: CloneDataFrameAnalyticsConfig): AnalyticsJo }, num_top_feature_importance_values: { optional: true, + defaultValue: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES, + formKey: 'numTopFeatureImportanceValues', }, randomize_seed: { optional: true, diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/create_analytics_form/create_analytics_form.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/create_analytics_form/create_analytics_form.tsx index 044bb9f517001..e5f30a50ed8f0 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/create_analytics_form/create_analytics_form.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/components/create_analytics_form/create_analytics_form.tsx @@ -10,6 +10,7 @@ import { EuiComboBox, EuiComboBoxOptionOption, EuiForm, + EuiFieldNumber, EuiFieldText, EuiFormRow, EuiLink, @@ -41,6 +42,7 @@ import { ANALYSIS_CONFIG_TYPE, DfAnalyticsExplainResponse, FieldSelectionItem, + NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN, TRAINING_PERCENT_MIN, TRAINING_PERCENT_MAX, } from '../../../../common/analytics'; @@ -83,6 +85,8 @@ export const CreateAnalyticsForm: FC = ({ actions, sta maxDistinctValuesError, modelMemoryLimit, modelMemoryLimitValidationResult, + numTopFeatureImportanceValues, + numTopFeatureImportanceValuesValid, previousJobType, previousSourceIndex, sourceIndex, @@ -645,6 +649,54 @@ export const CreateAnalyticsForm: FC = ({ actions, sta data-test-subj="mlAnalyticsCreateJobFlyoutTrainingPercentSlider" /> + {/* num_top_feature_importance_values */} + + {i18n.translate( + 'xpack.ml.dataframe.analytics.create.numTopFeatureImportanceValuesErrorText', + { + defaultMessage: + 'Invalid maximum number of feature importance values.', + } + )} + , + ] + : []), + ]} + > + setFormState({ numTopFeatureImportanceValues: +e.target.value })} + step={1} + value={numTopFeatureImportanceValues} + /> + )} merge(getInitialState(), { form: { @@ -34,7 +41,11 @@ const getMockState = ({ source: { index }, dest: { index: 'the-destination-index' }, analysis: { - classification: { dependent_variable: 'the-variable', training_percent: trainingPercent }, + classification: { + dependent_variable: 'the-variable', + num_top_feature_importance_values: numTopFeatureImportanceValues, + training_percent: trainingPercent, + }, }, model_memory_limit: modelMemoryLimit, }, @@ -173,6 +184,27 @@ describe('useCreateAnalyticsForm', () => { .isValid ).toBe(false); }); + + test('validateAdvancedEditor(): check num_top_feature_importance_values validation', () => { + // valid num_top_feature_importance_values value + expect( + validateAdvancedEditor( + getMockState({ index: 'the-source-index', numTopFeatureImportanceValues: 1 }) + ).isValid + ).toBe(true); + // invalid num_top_feature_importance_values numeric value + expect( + validateAdvancedEditor( + getMockState({ index: 'the-source-index', numTopFeatureImportanceValues: -1 }) + ).isValid + ).toBe(false); + // invalid training_percent numeric value if not an integer + expect( + validateAdvancedEditor( + getMockState({ index: 'the-source-index', numTopFeatureImportanceValues: 1.1 }) + ).isValid + ).toBe(false); + }); }); describe('validateMinMML', () => { @@ -194,3 +226,24 @@ describe('validateMinMML', () => { expect(validateMinMML((undefined as unknown) as string)('')).toEqual(null); }); }); + +describe('validateNumTopFeatureImportanceValues()', () => { + test('should not allow below 0', () => { + expect(validateNumTopFeatureImportanceValues(-1)).toBe(false); + }); + + test('should not allow strings', () => { + expect(validateNumTopFeatureImportanceValues('1')).toBe(false); + }); + + test('should not allow floats', () => { + expect(validateNumTopFeatureImportanceValues(0.1)).toBe(false); + expect(validateNumTopFeatureImportanceValues(1.1)).toBe(false); + expect(validateNumTopFeatureImportanceValues(-1.1)).toBe(false); + }); + + test('should allow 0 and higher', () => { + expect(validateNumTopFeatureImportanceValues(0)).toBe(true); + expect(validateNumTopFeatureImportanceValues(1)).toBe(true); + }); +}); diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/reducer.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/reducer.ts index 28d8afbcd88cc..c28a92bcb104c 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/reducer.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/reducer.ts @@ -31,10 +31,12 @@ import { } from '../../../../../../../common/constants/validation'; import { getDependentVar, + getNumTopFeatureImportanceValues, getTrainingPercent, isRegressionAnalysis, isClassificationAnalysis, ANALYSIS_CONFIG_TYPE, + NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN, TRAINING_PERCENT_MIN, TRAINING_PERCENT_MAX, } from '../../../../common/analytics'; @@ -100,6 +102,19 @@ const getSourceIndexString = (state: State) => { return ''; }; +/** + * Validates num_top_feature_importance_values. Must be an integer >= 0. + */ +export const validateNumTopFeatureImportanceValues = ( + numTopFeatureImportanceValues: any +): boolean => { + return ( + typeof numTopFeatureImportanceValues === 'number' && + numTopFeatureImportanceValues >= NUM_TOP_FEATURE_IMPORTANCE_VALUES_MIN && + Number.isInteger(numTopFeatureImportanceValues) + ); +}; + export const validateAdvancedEditor = (state: State): State => { const { jobIdEmpty, @@ -147,6 +162,7 @@ export const validateAdvancedEditor = (state: State): State => { let dependentVariableEmpty = false; let excludesValid = true; let trainingPercentValid = true; + let numTopFeatureImportanceValuesValid = true; if ( jobConfig.analysis === undefined && @@ -180,6 +196,7 @@ export const validateAdvancedEditor = (state: State): State => { if ( trainingPercent !== undefined && (isNaN(trainingPercent) || + typeof trainingPercent !== 'number' || trainingPercent < TRAINING_PERCENT_MIN || trainingPercent > TRAINING_PERCENT_MAX) ) { @@ -189,7 +206,7 @@ export const validateAdvancedEditor = (state: State): State => { error: i18n.translate( 'xpack.ml.dataframe.analytics.create.advancedEditorMessage.trainingPercentInvalid', { - defaultMessage: 'The training percent must be a value between {min} and {max}.', + defaultMessage: 'The training percent must be a number between {min} and {max}.', values: { min: TRAINING_PERCENT_MIN, max: TRAINING_PERCENT_MAX, @@ -199,6 +216,28 @@ export const validateAdvancedEditor = (state: State): State => { message: '', }); } + + const numTopFeatureImportanceValues = getNumTopFeatureImportanceValues(jobConfig.analysis); + if (numTopFeatureImportanceValues !== undefined) { + numTopFeatureImportanceValuesValid = validateNumTopFeatureImportanceValues( + numTopFeatureImportanceValues + ); + if (numTopFeatureImportanceValuesValid === false) { + state.advancedEditorMessages.push({ + error: i18n.translate( + 'xpack.ml.dataframe.analytics.create.advancedEditorMessage.numTopFeatureImportanceValuesInvalid', + { + defaultMessage: + 'The value for num_top_feature_importance_values must be an integer of {min} or higher.', + values: { + min: 0, + }, + } + ), + message: '', + }); + } + } } if (sourceIndexNameEmpty) { @@ -290,6 +329,7 @@ export const validateAdvancedEditor = (state: State): State => { destinationIndexNameValid && !dependentVariableEmpty && !modelMemoryLimitEmpty && + numTopFeatureImportanceValuesValid && (!destinationIndexPatternTitleExists || !createIndexPattern); return state; @@ -343,6 +383,7 @@ const validateForm = (state: State): State => { dependentVariable, maxDistinctValuesError, modelMemoryLimit, + numTopFeatureImportanceValuesValid, } = state.form; const { estimatedModelMemoryLimit } = state; @@ -368,6 +409,7 @@ const validateForm = (state: State): State => { !destinationIndexNameEmpty && destinationIndexNameValid && !dependentVariableEmpty && + numTopFeatureImportanceValuesValid && (!destinationIndexPatternTitleExists || !createIndexPattern); return state; @@ -443,6 +485,12 @@ export function reducer(state: State, action: Action): State { newFormState.sourceIndexNameValid = Object.keys(validationMessages).length === 0; } + if (action.payload.numTopFeatureImportanceValues !== undefined) { + newFormState.numTopFeatureImportanceValuesValid = validateNumTopFeatureImportanceValues( + newFormState?.numTopFeatureImportanceValues + ); + } + return state.isAdvancedEditorEnabled ? validateAdvancedEditor({ ...state, form: newFormState }) : validateForm({ ...state, form: newFormState }); diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/state.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/state.ts index fe741fe9a92d4..01a39d2ef9f3b 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/state.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_management/hooks/use_create_analytics_form/state.ts @@ -25,6 +25,8 @@ export enum DEFAULT_MODEL_MEMORY_LIMIT { classification = '100mb', } +export const DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES = 2; + export type EsIndexName = string; export type DependentVariable = string; export type IndexPatternTitle = string; @@ -69,6 +71,8 @@ export interface State { modelMemoryLimit: string | undefined; modelMemoryLimitUnitValid: boolean; modelMemoryLimitValidationResult: any; + numTopFeatureImportanceValues: number | undefined; + numTopFeatureImportanceValuesValid: boolean; previousJobType: null | AnalyticsJobType; previousSourceIndex: EsIndexName | undefined; sourceIndex: EsIndexName; @@ -124,6 +128,8 @@ export const getInitialState = (): State => ({ modelMemoryLimit: undefined, modelMemoryLimitUnitValid: true, modelMemoryLimitValidationResult: null, + numTopFeatureImportanceValues: DEFAULT_NUM_TOP_FEATURE_IMPORTANCE_VALUES, + numTopFeatureImportanceValuesValid: true, previousJobType: null, previousSourceIndex: undefined, sourceIndex: '', @@ -184,6 +190,7 @@ export const getJobConfigFromFormState = ( jobConfig.analysis = { [formState.jobType]: { dependent_variable: formState.dependentVariable, + num_top_feature_importance_values: formState.numTopFeatureImportanceValues, training_percent: formState.trainingPercent, }, }; @@ -218,6 +225,7 @@ export function getCloneFormStateFromJobConfig( const analysisConfig = analyticsJobConfig.analysis[jobType]; resultState.dependentVariable = analysisConfig.dependent_variable; + resultState.numTopFeatureImportanceValues = analysisConfig.num_top_feature_importance_values; resultState.trainingPercent = analysisConfig.training_percent; }