Skip to content

Commit

Permalink
[Enterprise Search] Replace model selection dropdown with list (#171436)
Browse files Browse the repository at this point in the history
## Summary

This PR replaces the model selection dropdown in the ML inference
pipeline configuration flyout with a cleaner selection list. The model
cards also contain fast deploy action buttons for promoted models
(ELSER, E5). The list is periodically updated.

Old:
![Screenshot 2023-11-16 at 12 31
50](https://github.com/elastic/kibana/assets/14224983/0b46f766-4423-4b70-be99-8cfe9fe26cfd)

New:
<img width="1442" alt="Screenshot 2023-11-30 at 15 13 46"
src="https://github.com/elastic/kibana/assets/14224983/fd439280-6dce-4973-b622-08ad3e34e665">

### Checklist

- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
- [x] Any UI touched in this PR is usable by keyboard only (learn more
about [keyboard accessibility](https://webaim.org/techniques/keyboard/))
- [ ] Any UI touched in this PR does not create any new axe failures
(run axe in browser:
[FF](https://addons.mozilla.org/en-US/firefox/addon/axe-devtools/),
[Chrome](https://chrome.google.com/webstore/detail/axe-web-accessibility-tes/lhdoppojpmngadmnindnejefpokejbdd?hl=en-US))
- [x] This renders correctly on smaller devices using a responsive
layout. (You can test this [in your
browser](https://www.browserstack.com/guide/responsive-testing-on-local-server))

---------

Co-authored-by: kibanamachine <42973632+kibanamachine@users.noreply.github.com>
  • Loading branch information
demjened and kibanamachine authored Dec 1, 2023
1 parent 1533f30 commit 2c4d0a3
Show file tree
Hide file tree
Showing 13 changed files with 942 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
*/

import { MlTrainedModelConfig, MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types';
import { BUILT_IN_MODEL_TAG } from '@kbn/ml-trained-models-utils';
import { BUILT_IN_MODEL_TAG, TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils';

import { MlInferencePipeline, TrainedModelState } from '../types/pipelines';

import {
generateMlInferencePipelineBody,
getMlModelTypesForModelConfig,
parseMlInferenceParametersFromPipeline,
parseModelState,
parseModelStateFromStats,
parseModelStateReasonFromStats,
} from '.';
Expand Down Expand Up @@ -265,8 +266,12 @@ describe('parseMlInferenceParametersFromPipeline', () => {
});

describe('parseModelStateFromStats', () => {
it('returns not deployed for undefined stats', () => {
expect(parseModelStateFromStats()).toEqual(TrainedModelState.NotDeployed);
it('returns Started for the lang_ident model', () => {
expect(
parseModelStateFromStats({
model_type: TRAINED_MODEL_TYPE.LANG_IDENT,
})
).toEqual(TrainedModelState.Started);
});
it('returns Started', () => {
expect(
Expand Down Expand Up @@ -315,6 +320,28 @@ describe('parseModelStateFromStats', () => {
});
});

// Unit tests for parseModelState: maps a raw deployment-state string to a
// TrainedModelState enum value. States that are equivalent from the UI's
// perspective are grouped and exercised in a loop.
describe('parseModelState', () => {
  it('returns Started', () => {
    for (const state of ['started', 'fully_allocated']) {
      expect(parseModelState(state)).toEqual(TrainedModelState.Started);
    }
  });
  it('returns Starting', () => {
    for (const state of ['starting', 'downloading', 'downloaded']) {
      expect(parseModelState(state)).toEqual(TrainedModelState.Starting);
    }
  });
  it('returns Stopping', () => {
    expect(parseModelState('stopping')).toEqual(TrainedModelState.Stopping);
  });
  it('returns Failed', () => {
    expect(parseModelState('failed')).toEqual(TrainedModelState.Failed);
  });
  it('returns NotDeployed for an unknown state', () => {
    for (const state of [undefined, 'other_state']) {
      expect(parseModelState(state)).toEqual(TrainedModelState.NotDeployed);
    }
  });
});

describe('parseModelStateReasonFromStats', () => {
it('returns reason from deployment_stats', () => {
const reason = 'This is the reason the model is in a failed state';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,18 @@ export const parseModelStateFromStats = (
modelTypes?.includes(TRAINED_MODEL_TYPE.LANG_IDENT)
)
return TrainedModelState.Started;
switch (model?.deployment_stats?.state) {

return parseModelState(model?.deployment_stats?.state);
};

export const parseModelState = (state?: string) => {
switch (state) {
case 'started':
case 'fully_allocated':
return TrainedModelState.Started;
case 'starting':
case 'downloading':
case 'downloaded':
return TrainedModelState.Starting;
case 'stopping':
return TrainedModelState.Stopping;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,12 @@ import {

import { i18n } from '@kbn/i18n';

import { IndexNameLogic } from '../../index_name_logic';
import { IndexViewLogic } from '../../index_view_logic';

import { EMPTY_PIPELINE_CONFIGURATION, MLInferenceLogic } from './ml_inference_logic';
import { MlModelSelectOption } from './model_select_option';
import { ModelSelect } from './model_select';
import { PipelineSelectOption } from './pipeline_select_option';
import { MODEL_REDACTED_VALUE, MODEL_SELECT_PLACEHOLDER, normalizeModelName } from './utils';

const MODEL_SELECT_PLACEHOLDER_VALUE = 'model_placeholder$$';
const PIPELINE_SELECT_PLACEHOLDER_VALUE = 'pipeline_placeholder$$';

const CREATE_NEW_TAB_NAME = i18n.translate(
Expand All @@ -55,32 +52,14 @@ export const ConfigurePipeline: React.FC = () => {
addInferencePipelineModal: { configuration },
formErrors,
existingInferencePipelines,
supportedMLModels,
} = useValues(MLInferenceLogic);
const { selectExistingPipeline, setInferencePipelineConfiguration } =
useActions(MLInferenceLogic);
const { ingestionMethod } = useValues(IndexViewLogic);
const { indexName } = useValues(IndexNameLogic);

const { existingPipeline, modelID, pipelineName, isPipelineNameUserSupplied } = configuration;
const { pipelineName } = configuration;

const nameError = formErrors.pipelineName !== undefined && pipelineName.length > 0;

const modelOptions: Array<EuiSuperSelectOption<string>> = [
{
disabled: true,
inputDisplay:
existingPipeline && pipelineName.length > 0
? MODEL_REDACTED_VALUE
: MODEL_SELECT_PLACEHOLDER,
value: MODEL_SELECT_PLACEHOLDER_VALUE,
},
...supportedMLModels.map((model) => ({
dropdownDisplay: <MlModelSelectOption model={model} />,
inputDisplay: model.model_id,
value: model.model_id,
})),
];
const pipelineOptions: Array<EuiSuperSelectOption<string>> = [
{
disabled: true,
Expand Down Expand Up @@ -161,26 +140,7 @@ export const ConfigurePipeline: React.FC = () => {
{ defaultMessage: 'Select a trained ML Model' }
)}
>
<EuiSuperSelect
data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-selectTrainedModel`}
fullWidth
hasDividers
disabled={inputsDisabled}
itemLayoutAlign="top"
onChange={(value) =>
setInferencePipelineConfiguration({
...configuration,
inferenceConfig: undefined,
modelID: value,
fieldMappings: undefined,
pipelineName: isPipelineNameUserSupplied
? pipelineName
: indexName + '-' + normalizeModelName(value),
})
}
options={modelOptions}
valueOfSelected={modelID === '' ? MODEL_SELECT_PLACEHOLDER_VALUE : modelID}
/>
<ModelSelect />
</EuiFormRow>
</EuiForm>
</>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { setMockActions, setMockValues } from '../../../../../__mocks__/kea_logic';

import React from 'react';

import { shallow } from 'enzyme';

import { EuiSelectable } from '@elastic/eui';

import { ModelSelect } from './model_select';

// Baseline mocked Kea state for the ModelSelect tests: an inference pipeline
// modal with an empty configuration, two selectable models, and the target
// index name used for auto-generated pipeline names.
const DEFAULT_VALUES = {
  addInferencePipelineModal: {
    configuration: {},
  },
  selectableModels: [
    {
      modelId: 'model_1',
    },
    {
      modelId: 'model_2',
    },
  ],
  indexName: 'my-index',
};
// Mocked Kea actions; the tests assert on the arguments this spy receives.
const MOCK_ACTIONS = {
  setInferencePipelineConfiguration: jest.fn(),
};

// Tests for the ModelSelect component: rendering of the EuiSelectable option
// list, selection state, and the pipeline configuration updates dispatched
// when an option is chosen.
describe('ModelSelect', () => {
  beforeEach(() => {
    jest.clearAllMocks();
    setMockValues({});
    setMockActions(MOCK_ACTIONS);
  });
  it('renders model select with no options', () => {
    // A null model list (e.g. before the first fetch) must render an empty
    // selectable rather than crashing.
    setMockValues({
      ...DEFAULT_VALUES,
      selectableModels: null,
    });

    const wrapper = shallow(<ModelSelect />);
    expect(wrapper.find(EuiSelectable)).toHaveLength(1);
    const selectable = wrapper.find(EuiSelectable);
    expect(selectable.prop('options')).toEqual([]);
  });
  it('renders model select with options', () => {
    setMockValues(DEFAULT_VALUES);

    const wrapper = shallow(<ModelSelect />);
    expect(wrapper.find(EuiSelectable)).toHaveLength(1);
    const selectable = wrapper.find(EuiSelectable);
    // Each model is mapped to an option whose label is its model ID.
    expect(selectable.prop('options')).toEqual([
      {
        modelId: 'model_1',
        label: 'model_1',
      },
      {
        modelId: 'model_2',
        label: 'model_2',
      },
    ]);
  });
  it('selects the chosen option', () => {
    setMockValues({
      ...DEFAULT_VALUES,
      addInferencePipelineModal: {
        configuration: {
          ...DEFAULT_VALUES.addInferencePipelineModal.configuration,
          modelID: 'model_2',
        },
      },
    });

    const wrapper = shallow(<ModelSelect />);
    expect(wrapper.find(EuiSelectable)).toHaveLength(1);
    const selectable = wrapper.find(EuiSelectable);
    // The option matching the configured modelID is checked.
    expect(selectable.prop('options')[1].checked).toEqual('on');
  });
  it('sets model ID on selecting an item and clears config', () => {
    setMockValues(DEFAULT_VALUES);

    const wrapper = shallow(<ModelSelect />);
    expect(wrapper.find(EuiSelectable)).toHaveLength(1);
    const selectable = wrapper.find(EuiSelectable);
    selectable.simulate('change', [{ modelId: 'model_1' }, { modelId: 'model_2', checked: 'on' }]);
    // Changing the model must reset any stale inference config and field mappings.
    expect(MOCK_ACTIONS.setInferencePipelineConfiguration).toHaveBeenCalledWith(
      expect.objectContaining({
        inferenceConfig: undefined,
        modelID: 'model_2',
        fieldMappings: undefined,
      })
    );
  });
  it('generates pipeline name on selecting an item', () => {
    setMockValues(DEFAULT_VALUES);

    const wrapper = shallow(<ModelSelect />);
    expect(wrapper.find(EuiSelectable)).toHaveLength(1);
    const selectable = wrapper.find(EuiSelectable);
    selectable.simulate('change', [{ modelId: 'model_1' }, { modelId: 'model_2', checked: 'on' }]);
    // Default name is "<indexName>-<normalized model ID>".
    expect(MOCK_ACTIONS.setInferencePipelineConfiguration).toHaveBeenCalledWith(
      expect.objectContaining({
        pipelineName: 'my-index-model_2',
      })
    );
  });
  // Fixed typo in the test description ("if it a name" -> "if a name").
  it('does not generate pipeline name on selecting an item if a name was supplied by the user', () => {
    setMockValues({
      ...DEFAULT_VALUES,
      addInferencePipelineModal: {
        configuration: {
          ...DEFAULT_VALUES.addInferencePipelineModal.configuration,
          pipelineName: 'user-pipeline',
          isPipelineNameUserSupplied: true,
        },
      },
    });

    const wrapper = shallow(<ModelSelect />);
    expect(wrapper.find(EuiSelectable)).toHaveLength(1);
    const selectable = wrapper.find(EuiSelectable);
    selectable.simulate('change', [{ modelId: 'model_1' }, { modelId: 'model_2', checked: 'on' }]);
    // The user-supplied name must be preserved verbatim.
    expect(MOCK_ACTIONS.setInferencePipelineConfiguration).toHaveBeenCalledWith(
      expect.objectContaining({
        pipelineName: 'user-pipeline',
      })
    );
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import React from 'react';

import { useActions, useValues } from 'kea';

import { EuiSelectable, useIsWithinMaxBreakpoint } from '@elastic/eui';

import { MlModel } from '../../../../../../../common/types/ml';
import { IndexNameLogic } from '../../index_name_logic';
import { IndexViewLogic } from '../../index_view_logic';

import { MLInferenceLogic } from './ml_inference_logic';
import { ModelSelectLogic } from './model_select_logic';
import { ModelSelectOption, ModelSelectOptionProps } from './model_select_option';
import { normalizeModelName } from './utils';

/**
 * Selectable list of trained ML models for the inference pipeline
 * configuration flyout (replaces the former dropdown).
 *
 * Selecting a model clears any previously chosen inference config and field
 * mappings and, unless the user supplied their own pipeline name, derives a
 * name from the index name and the normalized model ID.
 */
export const ModelSelect: React.FC = () => {
  const { indexName } = useValues(IndexNameLogic);
  const { ingestionMethod } = useValues(IndexViewLogic);
  const {
    addInferencePipelineModal: { configuration },
  } = useValues(MLInferenceLogic);
  const { selectableModels, isLoading } = useValues(ModelSelectLogic);
  const { setInferencePipelineConfiguration } = useActions(MLInferenceLogic);

  const { modelID, pipelineName, isPipelineNameUserSupplied } = configuration;

  // Hoisted out of the JSX prop object: hooks should be called as plain
  // top-level statements of the component so the call site is obviously
  // unconditional. Taller rows on small screens where option content wraps.
  const isSmallScreen = useIsWithinMaxBreakpoint('s');

  // Maps models to EuiSelectable options, marking the currently configured
  // model as checked. Tolerates a null list (e.g. before the first fetch
  // resolves), rendering an empty selectable.
  const getModelSelectOptionProps = (models: MlModel[] | null): ModelSelectOptionProps[] =>
    (models ?? []).map((model) => ({
      ...model,
      label: model.modelId,
      checked: model.modelId === modelID ? 'on' : undefined,
    }));

  const onChange = (options: ModelSelectOptionProps[]) => {
    const selectedOption = options.find((option) => option.checked === 'on');
    setInferencePipelineConfiguration({
      ...configuration,
      // A different model invalidates any previous inference config/mappings.
      inferenceConfig: undefined,
      modelID: selectedOption?.modelId ?? '',
      fieldMappings: undefined,
      // Only auto-generate the pipeline name if the user hasn't typed one.
      pipelineName: isPipelineNameUserSupplied
        ? pipelineName
        : indexName + '-' + normalizeModelName(selectedOption?.modelId ?? ''),
    });
  };

  const renderOption = (option: ModelSelectOptionProps) => <ModelSelectOption {...option} />;

  return (
    <EuiSelectable
      data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-selectTrainedModel`}
      options={getModelSelectOptionProps(selectableModels)}
      singleSelection="always"
      listProps={{
        bordered: true,
        rowHeight: isSmallScreen ? 180 : 90,
        showIcons: false,
        onFocusBadge: false,
      }}
      height={360}
      onChange={onChange}
      renderOption={renderOption}
      isLoading={isLoading}
      searchable
    >
      {(list, search) => (
        <>
          {search}
          {list}
        </>
      )}
    </EuiSelectable>
  );
};
Loading

0 comments on commit 2c4d0a3

Please sign in to comment.