diff --git a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/index.ts b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/index.ts index c72b0eb5fd66e5..216b0d8d5e9920 100644 --- a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/index.ts +++ b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/index.ts @@ -6,6 +6,4 @@ */ export { useScatterplotFieldOptions } from './use_scatterplot_field_options'; -export { LEGEND_TYPES } from './scatterplot_matrix_vega_lite_spec'; -export { ScatterplotMatrix } from './scatterplot_matrix'; -export type { ScatterplotMatrixViewProps as ScatterplotMatrixProps } from './scatterplot_matrix_view'; +export { ScatterplotMatrix, ScatterplotMatrixProps } from './scatterplot_matrix'; diff --git a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_view.scss b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix.scss similarity index 100% rename from x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_view.scss rename to x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix.scss diff --git a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix.tsx b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix.tsx index 8a10fd5574ba59..a4f68c84ba81f1 100644 --- a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix.tsx +++ b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix.tsx @@ -5,15 +5,305 @@ * 2.0. */ -import React, { FC, Suspense } from 'react'; +import React, { useMemo, useEffect, useState, FC } from 'react'; -import type { ScatterplotMatrixViewProps } from './scatterplot_matrix_view'; -import { ScatterplotMatrixLoading } from './scatterplot_matrix_loading'; +import { + EuiComboBox, + EuiComboBoxOptionOption, + EuiFlexGroup, + EuiFlexItem, + EuiFormRow, + EuiSelect, + EuiSwitch, +} from '@elastic/eui'; -const ScatterplotMatrixLazy = React.lazy(() => import('./scatterplot_matrix_view')); +import { i18n } from '@kbn/i18n'; -export const ScatterplotMatrix: FC = (props) => ( - }> - - -); +import type { SearchResponse7 } from '../../../../common/types/es_client'; +import type { ResultsSearchQuery } from '../../data_frame_analytics/common/analytics'; + +import { useMlApiContext } from '../../contexts/kibana'; + +import { getProcessedFields } from '../data_grid'; +import { useCurrentEuiTheme } from '../color_range_legend'; + +// Separate imports for lazy loadable VegaChart and related code +import { VegaChart } from '../vega_chart'; +import type { LegendType } from '../vega_chart/common'; +import { VegaChartLoading } from '../vega_chart/vega_chart_loading'; + +import { + getScatterplotMatrixVegaLiteSpec, + OUTLIER_SCORE_FIELD, +} from './scatterplot_matrix_vega_lite_spec'; + +import './scatterplot_matrix.scss'; + +const SCATTERPLOT_MATRIX_DEFAULT_FIELDS = 4; +const SCATTERPLOT_MATRIX_DEFAULT_FETCH_SIZE = 1000; +const SCATTERPLOT_MATRIX_DEFAULT_FETCH_MIN_SIZE = 1; +const SCATTERPLOT_MATRIX_DEFAULT_FETCH_MAX_SIZE = 10000; + +const TOGGLE_ON = i18n.translate('xpack.ml.splom.toggleOn', { + defaultMessage: 'On', +}); +const TOGGLE_OFF = i18n.translate('xpack.ml.splom.toggleOff', { + defaultMessage: 'Off', +}); + +const sampleSizeOptions = [100, 1000, 10000].map((d) => ({ value: d, text: '' + d })); + +export interface ScatterplotMatrixProps { + fields: string[]; + index: string; + resultsField?: string; + color?: string; + legendType?: LegendType; + searchQuery?: ResultsSearchQuery; +} + +export const ScatterplotMatrix: FC = ({ + fields: allFields, + index, + resultsField, + color, + legendType, + searchQuery, +}) => { + const { esSearch } = useMlApiContext(); + + // dynamicSize is optionally used for outlier charts where the scatterplot marks + // are sized according to outlier_score + const [dynamicSize, setDynamicSize] = useState(false); + + // used to give the use the option to customize the fields used for the matrix axes + const [fields, setFields] = useState([]); + + useEffect(() => { + const defaultFields = + allFields.length > SCATTERPLOT_MATRIX_DEFAULT_FIELDS + ? allFields.slice(0, SCATTERPLOT_MATRIX_DEFAULT_FIELDS) + : allFields; + setFields(defaultFields); + }, [allFields]); + + // the amount of documents to be fetched + const [fetchSize, setFetchSize] = useState(SCATTERPLOT_MATRIX_DEFAULT_FETCH_SIZE); + // flag to add a random score to the ES query to fetch documents + const [randomizeQuery, setRandomizeQuery] = useState(false); + + const [isLoading, setIsLoading] = useState(false); + + // contains the fetched documents and columns to be passed on to the Vega spec. + const [splom, setSplom] = useState<{ items: any[]; columns: string[] } | undefined>(); + + // formats the array of field names for EuiComboBox + const fieldOptions = useMemo( + () => + allFields.map((d) => ({ + label: d, + })), + [allFields] + ); + + const fieldsOnChange = (newFields: EuiComboBoxOptionOption[]) => { + setFields(newFields.map((d) => d.label)); + }; + + const fetchSizeOnChange = (e: React.ChangeEvent) => { + setFetchSize( + Math.min( + Math.max(parseInt(e.target.value, 10), SCATTERPLOT_MATRIX_DEFAULT_FETCH_MIN_SIZE), + SCATTERPLOT_MATRIX_DEFAULT_FETCH_MAX_SIZE + ) + ); + }; + + const randomizeQueryOnChange = () => { + setRandomizeQuery(!randomizeQuery); + }; + + const dynamicSizeOnChange = () => { + setDynamicSize(!dynamicSize); + }; + + const { euiTheme } = useCurrentEuiTheme(); + + useEffect(() => { + if (fields.length === 0) { + setSplom(undefined); + setIsLoading(false); + return; + } + + async function fetchSplom(options: { didCancel: boolean }) { + setIsLoading(true); + try { + const queryFields = [ + ...fields, + ...(color !== undefined ? [color] : []), + ...(legendType !== undefined ? [] : [`${resultsField}.${OUTLIER_SCORE_FIELD}`]), + ]; + + const queryFallback = searchQuery !== undefined ? searchQuery : { match_all: {} }; + const query = randomizeQuery + ? { + function_score: { + query: queryFallback, + random_score: { seed: 10, field: '_seq_no' }, + }, + } + : queryFallback; + + const resp: SearchResponse7 = await esSearch({ + index, + body: { + fields: queryFields, + _source: false, + query, + from: 0, + size: fetchSize, + }, + }); + + if (!options.didCancel) { + const items = resp.hits.hits.map((d) => + getProcessedFields(d.fields, (key: string) => + key.startsWith(`${resultsField}.feature_importance`) + ) + ); + + setSplom({ columns: fields, items }); + setIsLoading(false); + } + } catch (e) { + // TODO error handling + setIsLoading(false); + } + } + + const options = { didCancel: false }; + fetchSplom(options); + return () => { + options.didCancel = true; + }; + // stringify the fields array and search, otherwise the comparator will trigger on new but identical instances. + }, [fetchSize, JSON.stringify({ fields, searchQuery }), index, randomizeQuery, resultsField]); + + const vegaSpec = useMemo(() => { + if (splom === undefined) { + return; + } + + const { items, columns } = splom; + + const values = + resultsField !== undefined + ? items + : items.map((d) => { + d[`${resultsField}.${OUTLIER_SCORE_FIELD}`] = 0; + return d; + }); + + return getScatterplotMatrixVegaLiteSpec( + values, + columns, + euiTheme, + resultsField, + color, + legendType, + dynamicSize + ); + }, [resultsField, splom, color, legendType, dynamicSize]); + + return ( + <> + {splom === undefined || vegaSpec === undefined ? ( + + ) : ( +
+ + + + ({ + label: d, + }))} + onChange={fieldsOnChange} + isClearable={true} + data-test-subj="mlScatterplotMatrixFieldsComboBox" + /> + + + + + + + + + + + + + {resultsField !== undefined && legendType === undefined && ( + + + + + + )} + + + +
+ )} + + ); +}; diff --git a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_vega_lite_spec.test.ts b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_vega_lite_spec.test.ts index 44fba189e856cf..c963b7509139b8 100644 --- a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_vega_lite_spec.test.ts +++ b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_vega_lite_spec.test.ts @@ -10,13 +10,14 @@ import { compile } from 'vega-lite/build-es5/vega-lite'; import euiThemeLight from '@elastic/eui/dist/eui_theme_light.json'; +import { LEGEND_TYPES } from '../vega_chart/common'; + import { getColorSpec, getScatterplotMatrixVegaLiteSpec, COLOR_OUTLIER, COLOR_RANGE_NOMINAL, DEFAULT_COLOR, - LEGEND_TYPES, } from './scatterplot_matrix_vega_lite_spec'; describe('getColorSpec()', () => { diff --git a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_vega_lite_spec.ts b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_vega_lite_spec.ts index e476123ad0f2a2..f99aa7c5c3de86 100644 --- a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_vega_lite_spec.ts +++ b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_vega_lite_spec.ts @@ -15,11 +15,7 @@ import { euiPaletteColorBlind, euiPaletteNegative, euiPalettePositive } from '@e import { i18n } from '@kbn/i18n'; -export const LEGEND_TYPES = { - NOMINAL: 'nominal', - QUANTITATIVE: 'quantitative', -} as const; -export type LegendType = typeof LEGEND_TYPES[keyof typeof LEGEND_TYPES]; +import { LegendType, LEGEND_TYPES } from '../vega_chart/common'; export const OUTLIER_SCORE_FIELD = 'outlier_score'; diff --git a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_view.tsx b/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_view.tsx deleted file mode 100644 index 7d32992ace84da..00000000000000 --- a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_view.tsx +++ /dev/null @@ -1,324 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -import React, { useMemo, useEffect, useState, FC } from 'react'; - -// There is still an issue with Vega Lite's typings with the strict mode Kibana is using. -// @ts-ignore -import { compile } from 'vega-lite/build-es5/vega-lite'; -import { parse, View, Warn } from 'vega'; -import { Handler } from 'vega-tooltip'; - -import { - htmlIdGenerator, - EuiComboBox, - EuiComboBoxOptionOption, - EuiFlexGroup, - EuiFlexItem, - EuiFormRow, - EuiSelect, - EuiSwitch, -} from '@elastic/eui'; - -import { i18n } from '@kbn/i18n'; - -import type { SearchResponse7 } from '../../../../common/types/es_client'; -import type { ResultsSearchQuery } from '../../data_frame_analytics/common/analytics'; - -import { useMlApiContext } from '../../contexts/kibana'; - -import { getProcessedFields } from '../data_grid'; -import { useCurrentEuiTheme } from '../color_range_legend'; - -import { ScatterplotMatrixLoading } from './scatterplot_matrix_loading'; - -import { - getScatterplotMatrixVegaLiteSpec, - LegendType, - OUTLIER_SCORE_FIELD, -} from './scatterplot_matrix_vega_lite_spec'; - -import './scatterplot_matrix_view.scss'; - -const SCATTERPLOT_MATRIX_DEFAULT_FIELDS = 4; -const SCATTERPLOT_MATRIX_DEFAULT_FETCH_SIZE = 1000; -const SCATTERPLOT_MATRIX_DEFAULT_FETCH_MIN_SIZE = 1; -const SCATTERPLOT_MATRIX_DEFAULT_FETCH_MAX_SIZE = 10000; - -const TOGGLE_ON = i18n.translate('xpack.ml.splom.toggleOn', { - defaultMessage: 'On', -}); -const TOGGLE_OFF = i18n.translate('xpack.ml.splom.toggleOff', { - defaultMessage: 'Off', -}); - -const sampleSizeOptions = [100, 1000, 10000].map((d) => ({ value: d, text: '' + d })); - -export interface ScatterplotMatrixViewProps { - fields: string[]; - index: string; - resultsField?: string; - color?: string; - legendType?: LegendType; - searchQuery?: ResultsSearchQuery; -} - -export const ScatterplotMatrixView: FC = ({ - fields: allFields, - index, - resultsField, - color, - legendType, - searchQuery, -}) => { - const { esSearch } = useMlApiContext(); - - // dynamicSize is optionally used for outlier charts where the scatterplot marks - // are sized according to outlier_score - const [dynamicSize, setDynamicSize] = useState(false); - - // used to give the use the option to customize the fields used for the matrix axes - const [fields, setFields] = useState([]); - - useEffect(() => { - const defaultFields = - allFields.length > SCATTERPLOT_MATRIX_DEFAULT_FIELDS - ? allFields.slice(0, SCATTERPLOT_MATRIX_DEFAULT_FIELDS) - : allFields; - setFields(defaultFields); - }, [allFields]); - - // the amount of documents to be fetched - const [fetchSize, setFetchSize] = useState(SCATTERPLOT_MATRIX_DEFAULT_FETCH_SIZE); - // flag to add a random score to the ES query to fetch documents - const [randomizeQuery, setRandomizeQuery] = useState(false); - - const [isLoading, setIsLoading] = useState(false); - - // contains the fetched documents and columns to be passed on to the Vega spec. - const [splom, setSplom] = useState<{ items: any[]; columns: string[] } | undefined>(); - - // formats the array of field names for EuiComboBox - const fieldOptions = useMemo( - () => - allFields.map((d) => ({ - label: d, - })), - [allFields] - ); - - const fieldsOnChange = (newFields: EuiComboBoxOptionOption[]) => { - setFields(newFields.map((d) => d.label)); - }; - - const fetchSizeOnChange = (e: React.ChangeEvent) => { - setFetchSize( - Math.min( - Math.max(parseInt(e.target.value, 10), SCATTERPLOT_MATRIX_DEFAULT_FETCH_MIN_SIZE), - SCATTERPLOT_MATRIX_DEFAULT_FETCH_MAX_SIZE - ) - ); - }; - - const randomizeQueryOnChange = () => { - setRandomizeQuery(!randomizeQuery); - }; - - const dynamicSizeOnChange = () => { - setDynamicSize(!dynamicSize); - }; - - const { euiTheme } = useCurrentEuiTheme(); - - useEffect(() => { - async function fetchSplom(options: { didCancel: boolean }) { - setIsLoading(true); - try { - const queryFields = [ - ...fields, - ...(color !== undefined ? [color] : []), - ...(legendType !== undefined ? [] : [`${resultsField}.${OUTLIER_SCORE_FIELD}`]), - ]; - - const queryFallback = searchQuery !== undefined ? searchQuery : { match_all: {} }; - const query = randomizeQuery - ? { - function_score: { - query: queryFallback, - random_score: { seed: 10, field: '_seq_no' }, - }, - } - : queryFallback; - - const resp: SearchResponse7 = await esSearch({ - index, - body: { - fields: queryFields, - _source: false, - query, - from: 0, - size: fetchSize, - }, - }); - - if (!options.didCancel) { - const items = resp.hits.hits.map((d) => - getProcessedFields(d.fields, (key: string) => - key.startsWith(`${resultsField}.feature_importance`) - ) - ); - - setSplom({ columns: fields, items }); - setIsLoading(false); - } - } catch (e) { - // TODO error handling - setIsLoading(false); - } - } - - const options = { didCancel: false }; - fetchSplom(options); - return () => { - options.didCancel = true; - }; - // stringify the fields array and search, otherwise the comparator will trigger on new but identical instances. - }, [fetchSize, JSON.stringify({ fields, searchQuery }), index, randomizeQuery, resultsField]); - - const htmlId = useMemo(() => htmlIdGenerator()(), []); - - useEffect(() => { - if (splom === undefined) { - return; - } - - const { items, columns } = splom; - - const values = - resultsField !== undefined - ? items - : items.map((d) => { - d[`${resultsField}.${OUTLIER_SCORE_FIELD}`] = 0; - return d; - }); - - const vegaSpec = getScatterplotMatrixVegaLiteSpec( - values, - columns, - euiTheme, - resultsField, - color, - legendType, - dynamicSize - ); - - const vgSpec = compile(vegaSpec).spec; - - const view = new View(parse(vgSpec)) - .logLevel(Warn) - .renderer('canvas') - .tooltip(new Handler().call) - .initialize(`#${htmlId}`); - - view.runAsync(); // evaluate and render the view - }, [resultsField, splom, color, legendType, dynamicSize]); - - return ( - <> - {splom === undefined ? ( - - ) : ( - <> - - - - ({ - label: d, - }))} - onChange={fieldsOnChange} - isClearable={true} - data-test-subj="mlScatterplotMatrixFieldsComboBox" - /> - - - - - - - - - - - - - {resultsField !== undefined && legendType === undefined && ( - - - - - - )} - - -
- - )} - - ); -}; - -// required for dynamic import using React.lazy() -// eslint-disable-next-line import/no-default-export -export default ScatterplotMatrixView; diff --git a/x-pack/plugins/ml/public/application/components/vega_chart/common.ts b/x-pack/plugins/ml/public/application/components/vega_chart/common.ts new file mode 100644 index 00000000000000..79254788ce7a69 --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/vega_chart/common.ts @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +export const LEGEND_TYPES = { + NOMINAL: 'nominal', + QUANTITATIVE: 'quantitative', +} as const; +export type LegendType = typeof LEGEND_TYPES[keyof typeof LEGEND_TYPES]; diff --git a/x-pack/plugins/ml/public/application/components/vega_chart/index.ts b/x-pack/plugins/ml/public/application/components/vega_chart/index.ts new file mode 100644 index 00000000000000..f1d5c3ed4523bc --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/vega_chart/index.ts @@ -0,0 +1,11 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +// Make sure to only export the component we can lazy load here. +// Code from other files in this directory should be imported directly from the file, +// otherwise we break the bundling approach using lazy loading. +export { VegaChart } from './vega_chart'; diff --git a/x-pack/plugins/ml/public/application/components/vega_chart/vega_chart.tsx b/x-pack/plugins/ml/public/application/components/vega_chart/vega_chart.tsx new file mode 100644 index 00000000000000..ab175908d9d797 --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/vega_chart/vega_chart.tsx @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { FC, Suspense } from 'react'; + +import { VegaChartLoading } from './vega_chart_loading'; +import type { VegaChartViewProps } from './vega_chart_view'; + +const VegaChartView = React.lazy(() => import('./vega_chart_view')); + +export const VegaChart: FC = (props) => ( + }> + + +); diff --git a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_loading.tsx b/x-pack/plugins/ml/public/application/components/vega_chart/vega_chart_loading.tsx similarity index 91% rename from x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_loading.tsx rename to x-pack/plugins/ml/public/application/components/vega_chart/vega_chart_loading.tsx index cdb4d99b041d54..8a5c1575f94d65 100644 --- a/x-pack/plugins/ml/public/application/components/scatterplot_matrix/scatterplot_matrix_loading.tsx +++ b/x-pack/plugins/ml/public/application/components/vega_chart/vega_chart_loading.tsx @@ -9,7 +9,7 @@ import React from 'react'; import { EuiLoadingSpinner, EuiSpacer, EuiText } from '@elastic/eui'; -export const ScatterplotMatrixLoading = () => { +export const VegaChartLoading = () => { return ( diff --git a/x-pack/plugins/ml/public/application/components/vega_chart/vega_chart_view.tsx b/x-pack/plugins/ml/public/application/components/vega_chart/vega_chart_view.tsx new file mode 100644 index 00000000000000..7774def574b696 --- /dev/null +++ b/x-pack/plugins/ml/public/application/components/vega_chart/vega_chart_view.tsx @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React, { useMemo, useEffect, FC } from 'react'; + +// There is still an issue with Vega Lite's typings with the strict mode Kibana is using. +// @ts-ignore +import type { TopLevelSpec } from 'vega-lite/build-es5/vega-lite'; + +// There is still an issue with Vega Lite's typings with the strict mode Kibana is using. +// @ts-ignore +import { compile } from 'vega-lite/build-es5/vega-lite'; +import { parse, View, Warn } from 'vega'; +import { Handler } from 'vega-tooltip'; + +import { htmlIdGenerator } from '@elastic/eui'; + +export interface VegaChartViewProps { + vegaSpec: TopLevelSpec; +} + +export const VegaChartView: FC = ({ vegaSpec }) => { + const htmlId = useMemo(() => htmlIdGenerator()(), []); + + useEffect(() => { + const vgSpec = compile(vegaSpec).spec; + + const view = new View(parse(vgSpec)) + .logLevel(Warn) + .renderer('canvas') + .tooltip(new Handler().call) + .initialize(`#${htmlId}`); + + view.runAsync(); // evaluate and render the view + }, [vegaSpec]); + + return
; +}; + +// required for dynamic import using React.lazy() +// eslint-disable-next-line import/no-default-export +export default VegaChartView; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts index 4f1799ed26f872..1c13177e44e7fc 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/common/analytics.ts @@ -154,11 +154,21 @@ export interface ConfusionMatrix { other_predicted_class_doc_count: number; } +export interface RocCurveItem { + fpr: number; + threshold: number; + tpr: number; +} + export interface ClassificationEvaluateResponse { classification: { - multiclass_confusion_matrix: { + multiclass_confusion_matrix?: { confusion_matrix: ConfusionMatrix[]; }; + auc_roc?: { + curve?: RocCurveItem[]; + value: number; + }; }; } @@ -244,7 +254,8 @@ export const isClassificationEvaluateResponse = ( return ( keys.length === 1 && keys[0] === ANALYSIS_CONFIG_TYPE.CLASSIFICATION && - arg?.classification?.multiclass_confusion_matrix !== undefined + (arg?.classification?.multiclass_confusion_matrix !== undefined || + arg?.classification?.auc_roc !== undefined) ); }; @@ -422,7 +433,8 @@ export enum REGRESSION_STATS { interface EvaluateMetrics { classification: { - multiclass_confusion_matrix: object; + multiclass_confusion_matrix?: object; + auc_roc?: { include_curve: boolean; class_name: string }; }; regression: { r_squared: object; @@ -442,6 +454,8 @@ interface LoadEvalDataConfig { ignoreDefaultQuery?: boolean; jobType: DataFrameAnalysisConfigType; requiresKeyword?: boolean; + rocCurveClassName?: string; + includeMulticlassConfusionMatrix?: boolean; } export const loadEvalData = async ({ @@ -454,6 +468,8 @@ export const loadEvalData = async ({ ignoreDefaultQuery, jobType, requiresKeyword, + rocCurveClassName, + includeMulticlassConfusionMatrix = true, }: LoadEvalDataConfig) => { const results: LoadEvaluateResult = { success: false, eval: null, error: null }; const defaultPredictionField = `${dependentVariable}_prediction`; @@ -469,7 +485,10 @@ export const loadEvalData = async ({ const metrics: EvaluateMetrics = { classification: { - multiclass_confusion_matrix: {}, + ...(includeMulticlassConfusionMatrix ? { multiclass_confusion_matrix: {} } : {}), + ...(rocCurveClassName !== undefined + ? { auc_roc: { include_curve: true, class_name: rocCurveClassName } } + : {}), }, regression: { r_squared: {}, diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/common/get_scatterplot_matrix_legend_type.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/common/get_scatterplot_matrix_legend_type.ts index a8b95a415ea539..2113f9385c5ef5 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/common/get_scatterplot_matrix_legend_type.ts +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/common/get_scatterplot_matrix_legend_type.ts @@ -9,7 +9,7 @@ import { ANALYSIS_CONFIG_TYPE } from './analytics'; import { AnalyticsJobType } from '../pages/analytics_management/hooks/use_create_analytics_form/state'; -import { LEGEND_TYPES } from '../../components/scatterplot_matrix/scatterplot_matrix_vega_lite_spec'; +import { LEGEND_TYPES } from '../../components/vega_chart/common'; export const getScatterplotMatrixLegendType = (jobType: AnalyticsJobType | 'unknown') => { switch (jobType) { diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/_classification_exploration.scss b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/_classification_exploration.scss index d1c507c5241d5f..73ced778821cfa 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/_classification_exploration.scss +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/_classification_exploration.scss @@ -1,3 +1,6 @@ +/* Fixed width so we can align it with the padding of the AUC ROC chart. */ +$labelColumnWidth: 80px; + /* Workaround for EuiDataGrid within a Flex Layout, this tricks browsers treating the width as a px value instead of % @@ -6,7 +9,7 @@ width: 100%; } -.mlDataFrameAnalyticsClassification__confusionMatrix { +.mlDataFrameAnalyticsClassification__evaluateSectionContent { padding: 0 5%; } @@ -14,7 +17,7 @@ The following two classes are a workaround to avoid having EuiDataGrid in a flex layout and just uses a legacy approach for a two column layout so we don't break IE11. */ -.mlDataFrameAnalyticsClassification__confusionMatrix:after { +.mlDataFrameAnalyticsClassification__evaluateSectionContent:after { content: ''; display: table; clear: both; @@ -22,7 +25,7 @@ .mlDataFrameAnalyticsClassification__actualLabel { float: left; - width: 8%; + width: $labelColumnWidth; padding-top: $euiSize * 4; } @@ -32,7 +35,7 @@ .mlDataFrameAnalyticsClassification__dataGridMinWidth { float: left; min-width: 480px; - width: 92%; + width: calc(100% - #{$labelColumnWidth}); .euiDataGridRowCell--boolean { text-transform: none; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/evaluate_panel.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/evaluate_panel.tsx index b7dec4e5a435ee..20866bf43a2f46 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/evaluate_panel.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/evaluate_panel.tsx @@ -21,26 +21,20 @@ import { EuiTitle, } from '@elastic/eui'; import { useMlKibana } from '../../../../../contexts/kibana'; + +// Separate imports for lazy loadable VegaChart and related code +import { VegaChart } from '../../../../../components/vega_chart'; +import { VegaChartLoading } from '../../../../../components/vega_chart/vega_chart_loading'; + import { ErrorCallout } from '../error_callout'; -import { - getDependentVar, - getPredictionFieldName, - loadEvalData, - loadDocsCount, - DataFrameAnalyticsConfig, -} from '../../../../common'; -import { isKeywordAndTextType } from '../../../../common/fields'; +import { getDependentVar, DataFrameAnalyticsConfig } from '../../../../common'; import { DataFrameTaskStateType } from '../../../analytics_management/components/analytics_list/common'; -import { - isResultsSearchBoolQuery, - isClassificationEvaluateResponse, - ConfusionMatrix, - ResultsSearchQuery, - ANALYSIS_CONFIG_TYPE, -} from '../../../../common/analytics'; +import { ResultsSearchQuery } from '../../../../common/analytics'; import { ExpandableSection, HEADER_ITEMS_LOADING } from '../expandable_section'; +import { getRocCurveChartVegaLiteSpec } from './get_roc_curve_chart_vega_lite_spec'; + import { getColumnData, ACTUAL_CLASS_ID, @@ -48,6 +42,10 @@ import { getTrailingControlColumns, } from './column_data'; +import { isTrainingFilter } from './is_training_filter'; +import { useRocCurve } from './use_roc_curve'; +import { useConfusionMatrix } from './use_confusion_matrix'; + export interface EvaluatePanelProps { jobConfig: DataFrameAnalyticsConfig; jobStatus?: DataFrameTaskStateType; @@ -81,7 +79,7 @@ const trainingDatasetHelpText = i18n.translate( } ); -function getHelpText(dataSubsetTitle: string) { +function getHelpText(dataSubsetTitle: string): string { let helpText = entireDatasetHelpText; if (dataSubsetTitle === SUBSET_TITLE.TESTING) { helpText = testingDatasetHelpText; @@ -95,77 +93,36 @@ export const EvaluatePanel: FC = ({ jobConfig, jobStatus, se const { services: { docLinks }, } = useMlKibana(); - const [isLoading, setIsLoading] = useState(false); - const [confusionMatrixData, setConfusionMatrixData] = useState([]); + const [columns, setColumns] = useState([]); const [columnsData, setColumnsData] = useState([]); const [showFullColumns, setShowFullColumns] = useState(false); const [popoverContents, setPopoverContents] = useState([]); - const [docsCount, setDocsCount] = useState(null); - const [error, setError] = useState(null); const [dataSubsetTitle, setDataSubsetTitle] = useState(SUBSET_TITLE.ENTIRE); // Column visibility - const [visibleColumns, setVisibleColumns] = useState(() => + const [visibleColumns, setVisibleColumns] = useState(() => columns.map(({ id }: { id: string }) => id) ); - const index = jobConfig.dest.index; - const dependentVariable = getDependentVar(jobConfig.analysis); - const predictionFieldName = getPredictionFieldName(jobConfig.analysis); - // default is 'ml' const resultsField = jobConfig.dest.results_field; - let requiresKeyword = false; + const isTraining = isTrainingFilter(searchQuery, resultsField); - const loadData = async ({ isTraining }: { isTraining: boolean | undefined }) => { - setIsLoading(true); - - try { - requiresKeyword = isKeywordAndTextType(dependentVariable); - } catch (e) { - // Additional error handling due to missing field type is handled by loadEvalData - console.error('Unable to load new field types', error); // eslint-disable-line no-console - } - - const evalData = await loadEvalData({ - isTraining, - index, - dependentVariable, - resultsField, - predictionFieldName, - searchQuery, - jobType: ANALYSIS_CONFIG_TYPE.CLASSIFICATION, - requiresKeyword, - }); - - const docsCountResp = await loadDocsCount({ - isTraining, - searchQuery, - resultsField, - destIndex: jobConfig.dest.index, - }); - - if ( - evalData.success === true && - evalData.eval && - isClassificationEvaluateResponse(evalData.eval) - ) { - const confusionMatrix = - evalData.eval?.classification?.multiclass_confusion_matrix?.confusion_matrix; - setError(null); - setConfusionMatrixData(confusionMatrix || []); - setIsLoading(false); - } else { - setIsLoading(false); - setConfusionMatrixData([]); - setError(evalData.error); - } + const { + confusionMatrixData, + docsCount, + error: errorConfusionMatrix, + isLoading: isLoadingConfusionMatrix, + } = useConfusionMatrix(jobConfig, searchQuery); - if (docsCountResp.success === true) { - setDocsCount(docsCountResp.docsCount); + useEffect(() => { + if (isTraining === undefined) { + setDataSubsetTitle(SUBSET_TITLE.ENTIRE); } else { - setDocsCount(null); + setDataSubsetTitle( + isTraining && isTraining === true ? SUBSET_TITLE.TRAINING : SUBSET_TITLE.TESTING + ); } - }; + }, [isTraining]); useEffect(() => { if (confusionMatrixData.length > 0) { @@ -198,48 +155,12 @@ export const EvaluatePanel: FC = ({ jobConfig, jobStatus, se } }, [confusionMatrixData]); - useEffect(() => { - let isTraining: boolean | undefined; - const query = - isResultsSearchBoolQuery(searchQuery) && (searchQuery.bool.should || searchQuery.bool.filter); - - if (query !== undefined && query !== false) { - for (let i = 0; i < query.length; i++) { - const clause = query[i]; - - if (clause.match && clause.match[`${resultsField}.is_training`] !== undefined) { - isTraining = clause.match[`${resultsField}.is_training`]; - break; - } else if ( - clause.bool && - (clause.bool.should !== undefined || clause.bool.filter !== undefined) - ) { - const innerQuery = clause.bool.should || clause.bool.filter; - if (innerQuery !== undefined) { - for (let j = 0; j < innerQuery.length; j++) { - const innerClause = innerQuery[j]; - if ( - innerClause.match && - innerClause.match[`${resultsField}.is_training`] !== undefined - ) { - isTraining = innerClause.match[`${resultsField}.is_training`]; - break; - } - } - } - } - } - } - if (isTraining === undefined) { - setDataSubsetTitle(SUBSET_TITLE.ENTIRE); - } else { - setDataSubsetTitle( - isTraining && isTraining === true ? SUBSET_TITLE.TRAINING : SUBSET_TITLE.TESTING - ); - } - - loadData({ isTraining }); - }, [JSON.stringify(searchQuery)]); + const { + rocCurveData, + classificationClasses, + error: errorRocCurve, + isLoading: isLoadingRocCurve, + } = useRocCurve(jobConfig, searchQuery, visibleColumns); const renderCellValue = ({ rowIndex, @@ -312,7 +233,7 @@ export const EvaluatePanel: FC = ({ jobConfig, jobStatus, se } headerItems={ - !isLoading + !isLoadingConfusionMatrix ? [ ...(jobStatus !== undefined ? [ @@ -348,94 +269,149 @@ export const EvaluatePanel: FC = ({ jobConfig, jobStatus, se } contentPadding={true} content={ - !isLoading ? ( - <> - {error !== null && } - {error === null && ( - <> - - - {getHelpText(dataSubsetTitle)} - - - - - - {/* BEGIN TABLE ELEMENTS */} - -
-
- - - -
-
- {columns.length > 0 && columnsData.length > 0 && ( - <> -
- - - -
- - + {!isLoadingConfusionMatrix ? ( + <> + {errorConfusionMatrix !== null && } + {errorConfusionMatrix === null && ( + <> + + + {getHelpText(dataSubsetTitle)} + + + + + + {/* BEGIN TABLE ELEMENTS */} + +
+
+ + - - )} + +
+
+ {columns.length > 0 && columnsData.length > 0 && ( + <> +
+ + + +
+ + + + )} +
-
- - )} - {/* END TABLE ELEMENTS */} - - ) : null + {/* END TABLE ELEMENTS */} + + )} + + ) : null} + {/* AUC ROC Chart */} + + + + + + + + + + + + {Array.isArray(errorRocCurve) && ( + + {errorRocCurve.map((e) => ( + <> + {e} +
+ + ))} + + } + /> + )} + {!isLoadingRocCurve && errorRocCurve === null && rocCurveData.length > 0 && ( +
+ +
+ )} + {isLoadingRocCurve && } + } /> diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/get_roc_curve_chart_vega_lite_spec.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/get_roc_curve_chart_vega_lite_spec.tsx new file mode 100644 index 00000000000000..b9e9c5720e5aa9 --- /dev/null +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/get_roc_curve_chart_vega_lite_spec.tsx @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +// There is still an issue with Vega Lite's typings with the strict mode Kibana is using. +// @ts-ignore +import type { TopLevelSpec } from 'vega-lite/build-es5/vega-lite'; + +import { euiPaletteColorBlind, euiPaletteGray } from '@elastic/eui'; + +import { i18n } from '@kbn/i18n'; + +import { LEGEND_TYPES } from '../../../../../components/vega_chart/common'; + +import { RocCurveItem } from '../../../../common/analytics'; + +const GRAY = euiPaletteGray(1)[0]; +const BASELINE = 'baseline'; +const SIZE = 300; + +// returns a custom color range that includes gray for the baseline +function getColorRangeNominal(classificationClasses: string[]) { + const legendItems = [...classificationClasses, BASELINE].sort(); + const baselineIndex = legendItems.indexOf(BASELINE); + + const colorRangeNominal = euiPaletteColorBlind({ rotations: 2 }).slice( + 0, + classificationClasses.length + ); + + colorRangeNominal.splice(baselineIndex, 0, GRAY); + + return colorRangeNominal; +} + +export interface RocCurveDataRow extends RocCurveItem { + class_name: string; +} + +export const getRocCurveChartVegaLiteSpec = ( + classificationClasses: string[], + data: RocCurveDataRow[], + legendTitle: string +): TopLevelSpec => { + // we append two rows which make up the data for the diagonal baseline + data.push({ tpr: 0, fpr: 0, threshold: 1, class_name: BASELINE }); + data.push({ tpr: 1, fpr: 1, threshold: 1, class_name: BASELINE }); + + const colorRangeNominal = getColorRangeNominal(classificationClasses); + + return { + $schema: 'https://vega.github.io/schema/vega-lite/v4.8.1.json', + // Left padding of 45px to align the left axis of the chart with the confusion matrix above. + padding: { left: 45, top: 0, right: 0, bottom: 0 }, + config: { + legend: { + orient: 'right', + }, + view: { + continuousHeight: SIZE, + continuousWidth: SIZE, + }, + }, + data: { + name: 'roc-curve-data', + }, + datasets: { + 'roc-curve-data': data, + }, + encoding: { + color: { + field: 'class_name', + type: LEGEND_TYPES.NOMINAL, + scale: { + range: colorRangeNominal, + }, + legend: { + title: legendTitle, + }, + }, + size: { + value: 2, + }, + strokeDash: { + condition: { + test: `(datum.class_name === '${BASELINE}')`, + value: [5, 5], + }, + value: [0], + }, + x: { + field: 'fpr', + sort: null, + title: i18n.translate('xpack.ml.dataframe.analytics.rocChartSpec.xAxisTitle', { + defaultMessage: 'False Positive Rate (FPR)', + }), + type: 'quantitative', + axis: { + tickColor: GRAY, + labelColor: GRAY, + domainColor: GRAY, + titleColor: GRAY, + }, + }, + y: { + field: 'tpr', + title: i18n.translate('xpack.ml.dataframe.analytics.rocChartSpec.yAxisTitle', { + defaultMessage: 'True Positive Rate (TPR) (a.k.a Recall)', + }), + type: 'quantitative', + axis: { + tickColor: GRAY, + labelColor: GRAY, + domainColor: GRAY, + titleColor: GRAY, + }, + }, + tooltip: [ + { type: LEGEND_TYPES.NOMINAL, field: 'class_name' }, + { type: LEGEND_TYPES.QUANTITATIVE, field: 'fpr' }, + { type: LEGEND_TYPES.QUANTITATIVE, field: 'tpr' }, + ], + }, + height: SIZE, + width: SIZE, + mark: 'line', + }; +}; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/is_training_filter.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/is_training_filter.ts new file mode 100644 index 00000000000000..21203f85bbe849 --- /dev/null +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/is_training_filter.ts @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { isResultsSearchBoolQuery, ResultsSearchQuery } from '../../../../common/analytics'; + +export type IsTraining = boolean | undefined; + +export function isTrainingFilter( + searchQuery: ResultsSearchQuery, + resultsField: string +): IsTraining { + let isTraining: IsTraining; + const query = + isResultsSearchBoolQuery(searchQuery) && (searchQuery.bool.should || searchQuery.bool.filter); + + if (query !== undefined && query !== false) { + for (let i = 0; i < query.length; i++) { + const clause = query[i]; + + if (clause.match && clause.match[`${resultsField}.is_training`] !== undefined) { + isTraining = clause.match[`${resultsField}.is_training`]; + break; + } else if ( + clause.bool && + (clause.bool.should !== undefined || clause.bool.filter !== undefined) + ) { + const innerQuery = clause.bool.should || clause.bool.filter; + if (innerQuery !== undefined) { + for (let j = 0; j < innerQuery.length; j++) { + const innerClause = innerQuery[j]; + if ( + innerClause.match && + innerClause.match[`${resultsField}.is_training`] !== undefined + ) { + isTraining = innerClause.match[`${resultsField}.is_training`]; + break; + } + } + } + } + } + } + + return isTraining; +} diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/use_confusion_matrix.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/use_confusion_matrix.ts new file mode 100644 index 00000000000000..be44a8e36ed009 --- /dev/null +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/use_confusion_matrix.ts @@ -0,0 +1,98 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { useState, useEffect } from 'react'; + +import { + isClassificationEvaluateResponse, + ConfusionMatrix, + ResultsSearchQuery, + ANALYSIS_CONFIG_TYPE, +} from '../../../../common/analytics'; +import { isKeywordAndTextType } from '../../../../common/fields'; + +import { + getDependentVar, + getPredictionFieldName, + loadEvalData, + loadDocsCount, + DataFrameAnalyticsConfig, +} from '../../../../common'; + +import { isTrainingFilter } from './is_training_filter'; + +export const useConfusionMatrix = ( + jobConfig: DataFrameAnalyticsConfig, + searchQuery: ResultsSearchQuery +) => { + const [confusionMatrixData, setConfusionMatrixData] = useState([]); + const [isLoading, setIsLoading] = useState(false); + const [docsCount, setDocsCount] = useState(null); + const [error, setError] = useState(null); + + useEffect(() => { + async function loadConfusionMatrixData() { + setIsLoading(true); + + let requiresKeyword = false; + const dependentVariable = getDependentVar(jobConfig.analysis); + const resultsField = jobConfig.dest.results_field; + const isTraining = isTrainingFilter(searchQuery, resultsField); + + try { + requiresKeyword = isKeywordAndTextType(dependentVariable); + } catch (e) { + // Additional error handling due to missing field type is handled by loadEvalData + console.error('Unable to load new field types', e); // eslint-disable-line no-console + } + + const evalData = await loadEvalData({ + isTraining, + index: jobConfig.dest.index, + dependentVariable, + resultsField, + predictionFieldName: getPredictionFieldName(jobConfig.analysis), + searchQuery, + jobType: ANALYSIS_CONFIG_TYPE.CLASSIFICATION, + requiresKeyword, + }); + + const docsCountResp = await loadDocsCount({ + isTraining, + searchQuery, + resultsField, + destIndex: jobConfig.dest.index, + }); + + if ( + evalData.success === true && + evalData.eval && + isClassificationEvaluateResponse(evalData.eval) + ) { + const confusionMatrix = + evalData.eval?.classification?.multiclass_confusion_matrix?.confusion_matrix; + setError(null); + setConfusionMatrixData(confusionMatrix || []); + setIsLoading(false); + } else { + setIsLoading(false); + setConfusionMatrixData([]); + setError(evalData.error); + } + + if (docsCountResp.success === true) { + setDocsCount(docsCountResp.docsCount); + } else { + setDocsCount(null); + } + } + + loadConfusionMatrixData(); + }, [JSON.stringify([jobConfig, searchQuery])]); + + return { confusionMatrixData, docsCount, error, isLoading }; +}; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/use_roc_curve.ts b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/use_roc_curve.ts new file mode 100644 index 00000000000000..8cdb6f86ebddab --- /dev/null +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/classification_exploration/use_roc_curve.ts @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { useState, useEffect } from 'react'; + +import { + isClassificationEvaluateResponse, + ResultsSearchQuery, + RocCurveItem, + ANALYSIS_CONFIG_TYPE, +} from '../../../../common/analytics'; +import { isKeywordAndTextType } from '../../../../common/fields'; + +import { + getDependentVar, + getPredictionFieldName, + loadEvalData, + DataFrameAnalyticsConfig, +} from '../../../../common'; + +import { ACTUAL_CLASS_ID, OTHER_CLASS_ID } from './column_data'; + +import { isTrainingFilter } from './is_training_filter'; + +interface RocCurveDataRow extends RocCurveItem { + class_name: string; +} + +export const useRocCurve = ( + jobConfig: DataFrameAnalyticsConfig, + searchQuery: ResultsSearchQuery, + visibleColumns: string[] +) => { + const classificationClasses = visibleColumns.filter( + (d) => d !== ACTUAL_CLASS_ID && d !== OTHER_CLASS_ID + ); + + const [rocCurveData, setRocCurveData] = useState([]); + const [isLoading, setIsLoading] = useState(false); + const [error, setError] = useState(null); + + useEffect(() => { + async function loadRocCurveData() { + setIsLoading(true); + + const dependentVariable = getDependentVar(jobConfig.analysis); + const resultsField = jobConfig.dest.results_field; + + const newRocCurveData: RocCurveDataRow[] = []; + + let requiresKeyword = false; + const errors: string[] = []; + + try { + requiresKeyword = isKeywordAndTextType(dependentVariable); + } catch (e) { + // Additional error handling due to missing field type is handled by loadEvalData + console.error('Unable to load new field types', e); // eslint-disable-line no-console + } + + for (let i = 0; i < classificationClasses.length; i++) { + const rocCurveClassName = classificationClasses[i]; + const evalData = await loadEvalData({ + isTraining: isTrainingFilter(searchQuery, resultsField), + index: jobConfig.dest.index, + dependentVariable, + resultsField, + predictionFieldName: getPredictionFieldName(jobConfig.analysis), + searchQuery, + jobType: ANALYSIS_CONFIG_TYPE.CLASSIFICATION, + requiresKeyword, + rocCurveClassName, + includeMulticlassConfusionMatrix: false, + }); + + if ( + evalData.success === true && + evalData.eval && + isClassificationEvaluateResponse(evalData.eval) + ) { + const auc = evalData.eval?.classification?.auc_roc?.value || 0; + const rocCurveDataForClass = (evalData.eval?.classification?.auc_roc?.curve || []).map( + (d) => ({ + class_name: `${rocCurveClassName} (AUC: ${Math.round(auc * 100000) / 100000})`, + ...d, + }) + ); + newRocCurveData.push(...rocCurveDataForClass); + } else if (evalData.error !== null) { + errors.push(evalData.error); + } + } + + setError(errors.length > 0 ? errors : null); + setRocCurveData(newRocCurveData); + setIsLoading(false); + } + + loadRocCurveData(); + }, [JSON.stringify([jobConfig, searchQuery, visibleColumns])]); + + return { rocCurveData, classificationClasses, error, isLoading }; +}; diff --git a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/error_callout/error_callout.tsx b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/error_callout/error_callout.tsx index d18e5b55794b52..81f5e535708092 100644 --- a/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/error_callout/error_callout.tsx +++ b/x-pack/plugins/ml/public/application/data_frame_analytics/pages/analytics_exploration/components/error_callout/error_callout.tsx @@ -10,7 +10,7 @@ import { i18n } from '@kbn/i18n'; import { EuiCallOut } from '@elastic/eui'; interface Props { - error: string; + error: string | JSX.Element; } export const ErrorCallout: FC = ({ error }) => { @@ -26,7 +26,7 @@ export const ErrorCallout: FC = ({ error }) => { ); // Job was created but not started so the destination index has not been created - if (error.includes('index_not_found')) { + if (typeof error === 'string' && error.includes('index_not_found')) { errorCallout = ( = ({ error }) => {

); - } else if (error.includes('No documents found')) { + } else if (typeof error === 'string' && error.includes('No documents found')) { // Job was started but no results have been written yet errorCallout = ( = ({ error }) => {

); - } else if (error.includes('userProvidedQueryBuilder')) { + } else if (typeof error === 'string' && error.includes('userProvidedQueryBuilder')) { // query bar syntax is incorrect errorCallout = ( = ({ )} + {isLoadingJobConfig === true && jobConfig === undefined && } + {isLoadingJobConfig === false && jobConfig !== undefined && isInitialized === true && ( + + )} + {isLoadingJobConfig === true && jobConfig !== undefined && totalFeatureImportance === undefined && } @@ -191,10 +196,7 @@ export const ExplorationPageWrapper: FC = ({ )} - {isLoadingJobConfig === true && jobConfig === undefined && } - {isLoadingJobConfig === false && jobConfig !== undefined && isInitialized === true && ( - - )} + {isLoadingJobConfig === true && jobConfig === undefined && } {isLoadingJobConfig === false && diff --git a/x-pack/test/functional/apps/ml/data_frame_analytics/classification_creation.ts b/x-pack/test/functional/apps/ml/data_frame_analytics/classification_creation.ts index 59f1775bb21177..1d67408b733603 100644 --- a/x-pack/test/functional/apps/ml/data_frame_analytics/classification_creation.ts +++ b/x-pack/test/functional/apps/ml/data_frame_analytics/classification_creation.ts @@ -41,6 +41,15 @@ export default function ({ getService }: FtrProviderContext) { modelMemory: '60mb', createIndexPattern: true, expected: { + rocCurveColorState: [ + // background + { key: '#FFFFFF', value: 93 }, + // tick/grid/axis + { key: '#98A2B3', value: 1 }, + { key: '#DDDDDD', value: 3 }, + // line + { key: '#6092C0', value: 1 }, + ], scatterplotMatrixColorStats: [ // background { key: '#000000', value: 94 }, @@ -102,7 +111,7 @@ export default function ({ getService }: FtrProviderContext) { await ml.dataFrameAnalyticsCreation.assertIncludeFieldsSelectionExists(); await ml.testExecution.logTestStep('displays the scatterplot matrix'); - await ml.dataFrameAnalyticsScatterplot.assertScatterplotMatrix( + await ml.dataFrameAnalyticsCanvasElement.assertCanvasElement( 'mlAnalyticsCreateJobWizardScatterplotMatrixFormRow', testData.expected.scatterplotMatrixColorStats ); @@ -221,11 +230,15 @@ export default function ({ getService }: FtrProviderContext) { await ml.testExecution.logTestStep('displays the results view for created job'); await ml.dataFrameAnalyticsTable.openResultsView(testData.jobId); await ml.dataFrameAnalyticsResults.assertClassificationEvaluatePanelElementsExists(); + await ml.dataFrameAnalyticsCanvasElement.assertCanvasElement( + 'mlDFAnalyticsClassificationExplorationRocCurveChart', + testData.expected.rocCurveColorState + ); await ml.dataFrameAnalyticsResults.assertClassificationTablePanelExists(); await ml.dataFrameAnalyticsResults.assertResultsTableExists(); await ml.dataFrameAnalyticsResults.assertResultsTableTrainingFiltersExist(); await ml.dataFrameAnalyticsResults.assertResultsTableNotEmpty(); - await ml.dataFrameAnalyticsScatterplot.assertScatterplotMatrix( + await ml.dataFrameAnalyticsCanvasElement.assertCanvasElement( 'mlDFExpandableSection-splom', testData.expected.scatterplotMatrixColorStats ); diff --git a/x-pack/test/functional/apps/ml/data_frame_analytics/outlier_detection_creation.ts b/x-pack/test/functional/apps/ml/data_frame_analytics/outlier_detection_creation.ts index 02535f158ee638..8b291fa36867a8 100644 --- a/x-pack/test/functional/apps/ml/data_frame_analytics/outlier_detection_creation.ts +++ b/x-pack/test/functional/apps/ml/data_frame_analytics/outlier_detection_creation.ts @@ -128,7 +128,7 @@ export default function ({ getService }: FtrProviderContext) { await ml.dataFrameAnalyticsCreation.assertIncludeFieldsSelectionExists(); await ml.testExecution.logTestStep('displays the scatterplot matrix'); - await ml.dataFrameAnalyticsScatterplot.assertScatterplotMatrix( + await ml.dataFrameAnalyticsCanvasElement.assertCanvasElement( 'mlAnalyticsCreateJobWizardScatterplotMatrixFormRow', testData.expected.scatterplotMatrixColorStatsWizard ); @@ -249,7 +249,7 @@ export default function ({ getService }: FtrProviderContext) { await ml.dataFrameAnalyticsResults.assertOutlierTablePanelExists(); await ml.dataFrameAnalyticsResults.assertResultsTableExists(); await ml.dataFrameAnalyticsResults.assertResultsTableNotEmpty(); - await ml.dataFrameAnalyticsScatterplot.assertScatterplotMatrix( + await ml.dataFrameAnalyticsCanvasElement.assertCanvasElement( 'mlDFExpandableSection-splom', testData.expected.scatterplotMatrixColorStatsResults ); diff --git a/x-pack/test/functional/apps/ml/data_frame_analytics/regression_creation.ts b/x-pack/test/functional/apps/ml/data_frame_analytics/regression_creation.ts index f41944e3409d76..4ce5d5b352e141 100644 --- a/x-pack/test/functional/apps/ml/data_frame_analytics/regression_creation.ts +++ b/x-pack/test/functional/apps/ml/data_frame_analytics/regression_creation.ts @@ -101,7 +101,7 @@ export default function ({ getService }: FtrProviderContext) { await ml.dataFrameAnalyticsCreation.assertIncludeFieldsSelectionExists(); await ml.testExecution.logTestStep('displays the scatterplot matrix'); - await ml.dataFrameAnalyticsScatterplot.assertScatterplotMatrix( + await ml.dataFrameAnalyticsCanvasElement.assertCanvasElement( 'mlAnalyticsCreateJobWizardScatterplotMatrixFormRow', testData.expected.scatterplotMatrixColorStats ); @@ -224,7 +224,7 @@ export default function ({ getService }: FtrProviderContext) { await ml.dataFrameAnalyticsResults.assertResultsTableExists(); await ml.dataFrameAnalyticsResults.assertResultsTableTrainingFiltersExist(); await ml.dataFrameAnalyticsResults.assertResultsTableNotEmpty(); - await ml.dataFrameAnalyticsScatterplot.assertScatterplotMatrix( + await ml.dataFrameAnalyticsCanvasElement.assertCanvasElement( 'mlDFExpandableSection-splom', testData.expected.scatterplotMatrixColorStats ); diff --git a/x-pack/test/functional/services/ml/data_frame_analytics_scatterplot.ts b/x-pack/test/functional/services/ml/data_frame_analytics_canvas_element.ts similarity index 72% rename from x-pack/test/functional/services/ml/data_frame_analytics_scatterplot.ts rename to x-pack/test/functional/services/ml/data_frame_analytics_canvas_element.ts index 39b387e2de650c..a354e0723d3772 100644 --- a/x-pack/test/functional/services/ml/data_frame_analytics_scatterplot.ts +++ b/x-pack/test/functional/services/ml/data_frame_analytics_canvas_element.ts @@ -9,14 +9,14 @@ import expect from '@kbn/expect'; import { FtrProviderContext } from '../../ftr_provider_context'; -export function MachineLearningDataFrameAnalyticsScatterplotProvider({ +export function MachineLearningDataFrameAnalyticsCanvasElementProvider({ getService, }: FtrProviderContext) { const canvasElement = getService('canvasElement'); const testSubjects = getService('testSubjects'); - return new (class AnalyticsScatterplot { - public async assertScatterplotMatrix( + return new (class AnalyticsCanvasElement { + public async assertCanvasElement( dataTestSubj: string, expectedColorStats: Array<{ key: string; @@ -24,16 +24,15 @@ export function MachineLearningDataFrameAnalyticsScatterplotProvider({ }> ) { await testSubjects.existOrFail(dataTestSubj); - await testSubjects.existOrFail('mlScatterplotMatrix'); const actualColorStats = await canvasElement.getColorStats( - `[data-test-subj="mlScatterplotMatrix"] canvas`, + `[data-test-subj="${dataTestSubj}"] canvas`, expectedColorStats, 1 ); expect(actualColorStats.every((d) => d.withinTolerance)).to.eql( true, - `Color stats for scatterplot matrix should be within tolerance. Expected: '${JSON.stringify( + `Color stats for canvas element should be within tolerance. Expected: '${JSON.stringify( expectedColorStats )}' (got '${JSON.stringify(actualColorStats)}')` ); diff --git a/x-pack/test/functional/services/ml/data_frame_analytics_results.ts b/x-pack/test/functional/services/ml/data_frame_analytics_results.ts index b6aba13054f75d..c08e13cedaaa5a 100644 --- a/x-pack/test/functional/services/ml/data_frame_analytics_results.ts +++ b/x-pack/test/functional/services/ml/data_frame_analytics_results.ts @@ -32,6 +32,7 @@ export function MachineLearningDataFrameAnalyticsResultsProvider({ async assertClassificationEvaluatePanelElementsExists() { await testSubjects.existOrFail('mlDFExpandableSection-ClassificationEvaluation'); await testSubjects.existOrFail('mlDFAnalyticsClassificationExplorationConfusionMatrix'); + await testSubjects.existOrFail('mlDFAnalyticsClassificationExplorationRocCurveChart'); }, async assertClassificationTablePanelExists() { diff --git a/x-pack/test/functional/services/ml/index.ts b/x-pack/test/functional/services/ml/index.ts index 91d009316cf9e8..ceee1ba7dc1acf 100644 --- a/x-pack/test/functional/services/ml/index.ts +++ b/x-pack/test/functional/services/ml/index.ts @@ -18,7 +18,7 @@ import { MachineLearningDataFrameAnalyticsProvider } from './data_frame_analytic import { MachineLearningDataFrameAnalyticsCreationProvider } from './data_frame_analytics_creation'; import { MachineLearningDataFrameAnalyticsEditProvider } from './data_frame_analytics_edit'; import { MachineLearningDataFrameAnalyticsResultsProvider } from './data_frame_analytics_results'; -import { MachineLearningDataFrameAnalyticsScatterplotProvider } from './data_frame_analytics_scatterplot'; +import { MachineLearningDataFrameAnalyticsCanvasElementProvider } from './data_frame_analytics_canvas_element'; import { MachineLearningDataFrameAnalyticsMapProvider } from './data_frame_analytics_map'; import { MachineLearningDataFrameAnalyticsTableProvider } from './data_frame_analytics_table'; import { MachineLearningDataVisualizerProvider } from './data_visualizer'; @@ -66,7 +66,7 @@ export function MachineLearningProvider(context: FtrProviderContext) { const dataFrameAnalyticsResults = MachineLearningDataFrameAnalyticsResultsProvider(context); const dataFrameAnalyticsMap = MachineLearningDataFrameAnalyticsMapProvider(context); const dataFrameAnalyticsTable = MachineLearningDataFrameAnalyticsTableProvider(context); - const dataFrameAnalyticsScatterplot = MachineLearningDataFrameAnalyticsScatterplotProvider( + const dataFrameAnalyticsCanvasElement = MachineLearningDataFrameAnalyticsCanvasElementProvider( context ); @@ -113,7 +113,7 @@ export function MachineLearningProvider(context: FtrProviderContext) { dataFrameAnalyticsResults, dataFrameAnalyticsMap, dataFrameAnalyticsTable, - dataFrameAnalyticsScatterplot, + dataFrameAnalyticsCanvasElement, dataVisualizer, dataVisualizerFileBased, dataVisualizerIndexBased,