diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts index f5f7945376d95..78af0a862d302 100644 --- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts +++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts @@ -6,7 +6,7 @@ */ import { MlTrainedModelConfig, MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types'; -import { BUILT_IN_MODEL_TAG } from '@kbn/ml-trained-models-utils'; +import { BUILT_IN_MODEL_TAG, TRAINED_MODEL_TYPE } from '@kbn/ml-trained-models-utils'; import { MlInferencePipeline, TrainedModelState } from '../types/pipelines'; @@ -14,6 +14,7 @@ import { generateMlInferencePipelineBody, getMlModelTypesForModelConfig, parseMlInferenceParametersFromPipeline, + parseModelState, parseModelStateFromStats, parseModelStateReasonFromStats, } from '.'; @@ -265,8 +266,12 @@ describe('parseMlInferenceParametersFromPipeline', () => { }); describe('parseModelStateFromStats', () => { - it('returns not deployed for undefined stats', () => { - expect(parseModelStateFromStats()).toEqual(TrainedModelState.NotDeployed); + it('returns Started for the lang_ident model', () => { + expect( + parseModelStateFromStats({ + model_type: TRAINED_MODEL_TYPE.LANG_IDENT, + }) + ).toEqual(TrainedModelState.Started); }); it('returns Started', () => { expect( @@ -315,6 +320,28 @@ describe('parseModelStateFromStats', () => { }); }); +describe('parseModelState', () => { + it('returns Started', () => { + expect(parseModelState('started')).toEqual(TrainedModelState.Started); + expect(parseModelState('fully_allocated')).toEqual(TrainedModelState.Started); + }); + it('returns Starting', () => { + expect(parseModelState('starting')).toEqual(TrainedModelState.Starting); + expect(parseModelState('downloading')).toEqual(TrainedModelState.Starting); + expect(parseModelState('downloaded')).toEqual(TrainedModelState.Starting); + }); + it('returns Stopping', () => { + expect(parseModelState('stopping')).toEqual(TrainedModelState.Stopping); + }); + it('returns Failed', () => { + expect(parseModelState('failed')).toEqual(TrainedModelState.Failed); + }); + it('returns NotDeployed for an unknown state', () => { + expect(parseModelState(undefined)).toEqual(TrainedModelState.NotDeployed); + expect(parseModelState('other_state')).toEqual(TrainedModelState.NotDeployed); + }); +}); + describe('parseModelStateReasonFromStats', () => { it('returns reason from deployment_stats', () => { const reason = 'This is the reason the model is in a failed state'; diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts index 95c6672df6928..5f56c1105b297 100644 --- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts +++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts @@ -202,10 +202,18 @@ export const parseModelStateFromStats = ( modelTypes?.includes(TRAINED_MODEL_TYPE.LANG_IDENT) ) return TrainedModelState.Started; - switch (model?.deployment_stats?.state) { + + return parseModelState(model?.deployment_stats?.state); +}; + +export const parseModelState = (state?: string) => { + switch (state) { case 'started': + case 'fully_allocated': return TrainedModelState.Started; case 'starting': + case 'downloading': + case 'downloaded': return TrainedModelState.Starting; case 'stopping': return TrainedModelState.Stopping; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx index 444fd87ef4160..75a0269f643cf 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/configure_pipeline.tsx @@ -24,15 +24,12 @@ import { import { i18n } from '@kbn/i18n'; -import { IndexNameLogic } from '../../index_name_logic'; import { IndexViewLogic } from '../../index_view_logic'; import { EMPTY_PIPELINE_CONFIGURATION, MLInferenceLogic } from './ml_inference_logic'; -import { MlModelSelectOption } from './model_select_option'; +import { ModelSelect } from './model_select'; import { PipelineSelectOption } from './pipeline_select_option'; -import { MODEL_REDACTED_VALUE, MODEL_SELECT_PLACEHOLDER, normalizeModelName } from './utils'; -const MODEL_SELECT_PLACEHOLDER_VALUE = 'model_placeholder$$'; const PIPELINE_SELECT_PLACEHOLDER_VALUE = 'pipeline_placeholder$$'; const CREATE_NEW_TAB_NAME = i18n.translate( @@ -55,32 +52,14 @@ export const ConfigurePipeline: React.FC = () => { addInferencePipelineModal: { configuration }, formErrors, existingInferencePipelines, - supportedMLModels, } = useValues(MLInferenceLogic); const { selectExistingPipeline, setInferencePipelineConfiguration } = useActions(MLInferenceLogic); const { ingestionMethod } = useValues(IndexViewLogic); - const { indexName } = useValues(IndexNameLogic); - - const { existingPipeline, modelID, pipelineName, isPipelineNameUserSupplied } = configuration; + const { pipelineName } = configuration; const nameError = formErrors.pipelineName !== undefined && pipelineName.length > 0; - const modelOptions: Array> = [ - { - disabled: true, - inputDisplay: - existingPipeline && pipelineName.length > 0 - ? MODEL_REDACTED_VALUE - : MODEL_SELECT_PLACEHOLDER, - value: MODEL_SELECT_PLACEHOLDER_VALUE, - }, - ...supportedMLModels.map((model) => ({ - dropdownDisplay: , - inputDisplay: model.model_id, - value: model.model_id, - })), - ]; const pipelineOptions: Array> = [ { disabled: true, @@ -161,26 +140,7 @@ export const ConfigurePipeline: React.FC = () => { { defaultMessage: 'Select a trained ML Model' } )} > - - setInferencePipelineConfiguration({ - ...configuration, - inferenceConfig: undefined, - modelID: value, - fieldMappings: undefined, - pipelineName: isPipelineNameUserSupplied - ? pipelineName - : indexName + '-' + normalizeModelName(value), - }) - } - options={modelOptions} - valueOfSelected={modelID === '' ? MODEL_SELECT_PLACEHOLDER_VALUE : modelID} - /> + diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.test.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.test.tsx new file mode 100644 index 0000000000000..15fb492fae56d --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.test.tsx @@ -0,0 +1,136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { setMockActions, setMockValues } from '../../../../../__mocks__/kea_logic'; + +import React from 'react'; + +import { shallow } from 'enzyme'; + +import { EuiSelectable } from '@elastic/eui'; + +import { ModelSelect } from './model_select'; + +const DEFAULT_VALUES = { + addInferencePipelineModal: { + configuration: {}, + }, + selectableModels: [ + { + modelId: 'model_1', + }, + { + modelId: 'model_2', + }, + ], + indexName: 'my-index', +}; +const MOCK_ACTIONS = { + setInferencePipelineConfiguration: jest.fn(), +}; + +describe('ModelSelect', () => { + beforeEach(() => { + jest.clearAllMocks(); + setMockValues({}); + setMockActions(MOCK_ACTIONS); + }); + it('renders model select with no options', () => { + setMockValues({ + ...DEFAULT_VALUES, + selectableModels: null, + }); + + const wrapper = shallow(); + expect(wrapper.find(EuiSelectable)).toHaveLength(1); + const selectable = wrapper.find(EuiSelectable); + expect(selectable.prop('options')).toEqual([]); + }); + it('renders model select with options', () => { + setMockValues(DEFAULT_VALUES); + + const wrapper = shallow(); + expect(wrapper.find(EuiSelectable)).toHaveLength(1); + const selectable = wrapper.find(EuiSelectable); + expect(selectable.prop('options')).toEqual([ + { + modelId: 'model_1', + label: 'model_1', + }, + { + modelId: 'model_2', + label: 'model_2', + }, + ]); + }); + it('selects the chosen option', () => { + setMockValues({ + ...DEFAULT_VALUES, + addInferencePipelineModal: { + configuration: { + ...DEFAULT_VALUES.addInferencePipelineModal.configuration, + modelID: 'model_2', + }, + }, + }); + + const wrapper = shallow(); + expect(wrapper.find(EuiSelectable)).toHaveLength(1); + const selectable = wrapper.find(EuiSelectable); + expect(selectable.prop('options')[1].checked).toEqual('on'); + }); + it('sets model ID on selecting an item and clears config', () => { + setMockValues(DEFAULT_VALUES); + + const wrapper = shallow(); + expect(wrapper.find(EuiSelectable)).toHaveLength(1); + const selectable = wrapper.find(EuiSelectable); + selectable.simulate('change', [{ modelId: 'model_1' }, { modelId: 'model_2', checked: 'on' }]); + expect(MOCK_ACTIONS.setInferencePipelineConfiguration).toHaveBeenCalledWith( + expect.objectContaining({ + inferenceConfig: undefined, + modelID: 'model_2', + fieldMappings: undefined, + }) + ); + }); + it('generates pipeline name on selecting an item', () => { + setMockValues(DEFAULT_VALUES); + + const wrapper = shallow(); + expect(wrapper.find(EuiSelectable)).toHaveLength(1); + const selectable = wrapper.find(EuiSelectable); + selectable.simulate('change', [{ modelId: 'model_1' }, { modelId: 'model_2', checked: 'on' }]); + expect(MOCK_ACTIONS.setInferencePipelineConfiguration).toHaveBeenCalledWith( + expect.objectContaining({ + pipelineName: 'my-index-model_2', + }) + ); + }); + it('does not generate pipeline name on selecting an item if it a name was supplied by the user', () => { + setMockValues({ + ...DEFAULT_VALUES, + addInferencePipelineModal: { + configuration: { + ...DEFAULT_VALUES.addInferencePipelineModal.configuration, + pipelineName: 'user-pipeline', + isPipelineNameUserSupplied: true, + }, + }, + }); + + const wrapper = shallow(); + expect(wrapper.find(EuiSelectable)).toHaveLength(1); + const selectable = wrapper.find(EuiSelectable); + selectable.simulate('change', [{ modelId: 'model_1' }, { modelId: 'model_2', checked: 'on' }]); + expect(MOCK_ACTIONS.setInferencePipelineConfiguration).toHaveBeenCalledWith( + expect.objectContaining({ + pipelineName: 'user-pipeline', + }) + ); + }); +}); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.tsx new file mode 100644 index 0000000000000..86c91c483702f --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select.tsx @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React from 'react'; + +import { useActions, useValues } from 'kea'; + +import { EuiSelectable, useIsWithinMaxBreakpoint } from '@elastic/eui'; + +import { MlModel } from '../../../../../../../common/types/ml'; +import { IndexNameLogic } from '../../index_name_logic'; +import { IndexViewLogic } from '../../index_view_logic'; + +import { MLInferenceLogic } from './ml_inference_logic'; +import { ModelSelectLogic } from './model_select_logic'; +import { ModelSelectOption, ModelSelectOptionProps } from './model_select_option'; +import { normalizeModelName } from './utils'; + +export const ModelSelect: React.FC = () => { + const { indexName } = useValues(IndexNameLogic); + const { ingestionMethod } = useValues(IndexViewLogic); + const { + addInferencePipelineModal: { configuration }, + } = useValues(MLInferenceLogic); + const { selectableModels, isLoading } = useValues(ModelSelectLogic); + const { setInferencePipelineConfiguration } = useActions(MLInferenceLogic); + + const { modelID, pipelineName, isPipelineNameUserSupplied } = configuration; + + const getModelSelectOptionProps = (models: MlModel[]): ModelSelectOptionProps[] => + (models ?? []).map((model) => ({ + ...model, + label: model.modelId, + checked: model.modelId === modelID ? 'on' : undefined, + })); + + const onChange = (options: ModelSelectOptionProps[]) => { + const selectedOption = options.find((option) => option.checked === 'on'); + setInferencePipelineConfiguration({ + ...configuration, + inferenceConfig: undefined, + modelID: selectedOption?.modelId ?? '', + fieldMappings: undefined, + pipelineName: isPipelineNameUserSupplied + ? pipelineName + : indexName + '-' + normalizeModelName(selectedOption?.modelId ?? ''), + }); + }; + + const renderOption = (option: ModelSelectOptionProps) => ; + + return ( + + {(list, search) => ( + <> + {search} + {list} + + )} + + ); +}; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.test.ts new file mode 100644 index 0000000000000..b38efe00210a0 --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.test.ts @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { LogicMounter } from '../../../../../__mocks__/kea_logic'; + +import { MlModel, MlModelDeploymentState } from '../../../../../../../common/types/ml'; +import { CachedFetchModelsApiLogic } from '../../../../api/ml_models/cached_fetch_models_api_logic'; +import { + CreateModelApiLogic, + CreateModelResponse, +} from '../../../../api/ml_models/create_model_api_logic'; +import { StartModelApiLogic } from '../../../../api/ml_models/start_model_api_logic'; + +import { ModelSelectLogic } from './model_select_logic'; + +const CREATE_MODEL_API_RESPONSE: CreateModelResponse = { + modelId: 'model_1', + deploymentState: MlModelDeploymentState.NotDeployed, +}; +const FETCH_MODELS_API_DATA_RESPONSE: MlModel[] = [ + { + modelId: 'model_1', + title: 'Model 1', + type: 'ner', + deploymentState: MlModelDeploymentState.NotDeployed, + startTime: 0, + targetAllocationCount: 0, + nodeAllocationCount: 0, + threadsPerAllocation: 0, + isPlaceholder: false, + hasStats: false, + }, +]; + +describe('ModelSelectLogic', () => { + const { mount } = new LogicMounter(ModelSelectLogic); + const { mount: mountCreateModelApiLogic } = new LogicMounter(CreateModelApiLogic); + const { mount: mountCachedFetchModelsApiLogic } = new LogicMounter(CachedFetchModelsApiLogic); + const { mount: mountStartModelApiLogic } = new LogicMounter(StartModelApiLogic); + + beforeEach(() => { + jest.clearAllMocks(); + mountCreateModelApiLogic(); + mountCachedFetchModelsApiLogic(); + mountStartModelApiLogic(); + mount(); + }); + + describe('listeners', () => { + describe('createModel', () => { + it('creates the model', () => { + const modelId = 'model_1'; + jest.spyOn(ModelSelectLogic.actions, 'createModelMakeRequest'); + + ModelSelectLogic.actions.createModel(modelId); + + expect(ModelSelectLogic.actions.createModelMakeRequest).toHaveBeenCalledWith({ modelId }); + }); + }); + + describe('createModelSuccess', () => { + it('starts polling models', () => { + jest.spyOn(ModelSelectLogic.actions, 'startPollingModels'); + + ModelSelectLogic.actions.createModelSuccess(CREATE_MODEL_API_RESPONSE); + + expect(ModelSelectLogic.actions.startPollingModels).toHaveBeenCalled(); + }); + }); + + describe('fetchModels', () => { + it('makes fetch models request', () => { + jest.spyOn(ModelSelectLogic.actions, 'fetchModelsMakeRequest'); + + ModelSelectLogic.actions.fetchModels(); + + expect(ModelSelectLogic.actions.fetchModelsMakeRequest).toHaveBeenCalled(); + }); + }); + + describe('startModel', () => { + it('makes start model request', () => { + const modelId = 'model_1'; + jest.spyOn(ModelSelectLogic.actions, 'startModelMakeRequest'); + + ModelSelectLogic.actions.startModel(modelId); + + expect(ModelSelectLogic.actions.startModelMakeRequest).toHaveBeenCalledWith({ modelId }); + }); + }); + + describe('startModelSuccess', () => { + it('starts polling models', () => { + jest.spyOn(ModelSelectLogic.actions, 'startPollingModels'); + + ModelSelectLogic.actions.startModelSuccess(CREATE_MODEL_API_RESPONSE); + + expect(ModelSelectLogic.actions.startPollingModels).toHaveBeenCalled(); + }); + }); + }); + + describe('selectors', () => { + describe('areActionButtonsDisabled', () => { + it('is set to false if create and start APIs are idle', () => { + CreateModelApiLogic.actions.apiReset(); + StartModelApiLogic.actions.apiReset(); + + expect(ModelSelectLogic.values.areActionButtonsDisabled).toBe(false); + }); + it('is set to true if create API is making a request', () => { + CreateModelApiLogic.actions.makeRequest({ modelId: 'model_1' }); + + expect(ModelSelectLogic.values.areActionButtonsDisabled).toBe(true); + }); + it('is set to true if start API is making a request', () => { + StartModelApiLogic.actions.makeRequest({ modelId: 'model_1' }); + + expect(ModelSelectLogic.values.areActionButtonsDisabled).toBe(true); + }); + }); + + describe('selectableModels', () => { + it('gets models data from API response', () => { + CachedFetchModelsApiLogic.actions.apiSuccess(FETCH_MODELS_API_DATA_RESPONSE); + + expect(ModelSelectLogic.values.selectableModels).toEqual(FETCH_MODELS_API_DATA_RESPONSE); + }); + }); + + describe('isLoading', () => { + it('is set to true if the fetch API is loading the first time', () => { + CachedFetchModelsApiLogic.actions.apiReset(); + + expect(ModelSelectLogic.values.isLoading).toBe(true); + }); + }); + }); +}); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.ts new file mode 100644 index 0000000000000..9f8c2b8b97612 --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_logic.ts @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { kea, MakeLogicType } from 'kea'; + +import { HttpError, Status } from '../../../../../../../common/types/api'; +import { MlModel } from '../../../../../../../common/types/ml'; +import { + CachedFetchModelsApiLogic, + CachedFetchModlesApiLogicActions, +} from '../../../../api/ml_models/cached_fetch_models_api_logic'; +import { + CreateModelApiLogic, + CreateModelApiLogicActions, +} from '../../../../api/ml_models/create_model_api_logic'; +import { FetchModelsApiResponse } from '../../../../api/ml_models/fetch_models_api_logic'; +import { + StartModelApiLogic, + StartModelApiLogicActions, +} from '../../../../api/ml_models/start_model_api_logic'; + +export interface ModelSelectActions { + createModel: (modelId: string) => { modelId: string }; + createModelMakeRequest: CreateModelApiLogicActions['makeRequest']; + createModelSuccess: CreateModelApiLogicActions['apiSuccess']; + + fetchModels: () => void; + fetchModelsMakeRequest: CachedFetchModlesApiLogicActions['makeRequest']; + fetchModelsError: CachedFetchModlesApiLogicActions['apiError']; + fetchModelsSuccess: CachedFetchModlesApiLogicActions['apiSuccess']; + startPollingModels: CachedFetchModlesApiLogicActions['startPolling']; + + startModel: (modelId: string) => { modelId: string }; + startModelMakeRequest: StartModelApiLogicActions['makeRequest']; + startModelSuccess: StartModelApiLogicActions['apiSuccess']; +} + +export interface ModelSelectValues { + areActionButtonsDisabled: boolean; + createModelError: HttpError | undefined; + createModelStatus: Status; + isLoading: boolean; + isInitialLoading: boolean; + modelsData: FetchModelsApiResponse | undefined; + modelsStatus: Status; + selectableModels: MlModel[]; + startModelError: HttpError | undefined; + startModelStatus: Status; +} + +export const ModelSelectLogic = kea>({ + actions: { + createModel: (modelId: string) => ({ modelId }), + fetchModels: true, + startModel: (modelId: string) => ({ modelId }), + }, + connect: { + actions: [ + CreateModelApiLogic, + [ + 'makeRequest as createModelMakeRequest', + 'apiSuccess as createModelSuccess', + 'apiError as createModelError', + ], + CachedFetchModelsApiLogic, + [ + 'makeRequest as fetchModelsMakeRequest', + 'apiSuccess as fetchModelsSuccess', + 'apiError as fetchModelsError', + 'startPolling as startPollingModels', + ], + StartModelApiLogic, + [ + 'makeRequest as startModelMakeRequest', + 'apiSuccess as startModelSuccess', + 'apiError as startModelError', + ], + ], + values: [ + CreateModelApiLogic, + ['status as createModelStatus', 'error as createModelError'], + CachedFetchModelsApiLogic, + ['modelsData', 'status as modelsStatus', 'isInitialLoading'], + StartModelApiLogic, + ['status as startModelStatus', 'error as startModelError'], + ], + }, + events: ({ actions }) => ({ + afterMount: () => { + actions.startPollingModels(); + }, + }), + listeners: ({ actions }) => ({ + createModel: ({ modelId }) => { + actions.createModelMakeRequest({ modelId }); + }, + createModelSuccess: () => { + actions.startPollingModels(); + }, + fetchModels: () => { + actions.fetchModelsMakeRequest({}); + }, + startModel: ({ modelId }) => { + actions.startModelMakeRequest({ modelId }); + }, + startModelSuccess: () => { + actions.startPollingModels(); + }, + }), + path: ['enterprise_search', 'content', 'model_select_logic'], + selectors: ({ selectors }) => ({ + areActionButtonsDisabled: [ + () => [selectors.createModelStatus, selectors.startModelStatus], + (createModelStatus: Status, startModelStatus: Status) => + createModelStatus === Status.LOADING || startModelStatus === Status.LOADING, + ], + selectableModels: [ + () => [selectors.modelsData], + (response: FetchModelsApiResponse) => response ?? [], + ], + isLoading: [() => [selectors.isInitialLoading], (isInitialLoading) => isInitialLoading], + }), +}); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.test.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.test.tsx new file mode 100644 index 0000000000000..411bb8947257c --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.test.tsx @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { setMockValues } from '../../../../../__mocks__/kea_logic'; + +import React from 'react'; + +import { shallow } from 'enzyme'; + +import { EuiBadge, EuiText } from '@elastic/eui'; + +import { MlModelDeploymentState } from '../../../../../../../common/types/ml'; +import { TrainedModelHealth } from '../ml_model_health'; + +import { + DeployModelButton, + getContextMenuPanel, + ModelSelectOption, + ModelSelectOptionProps, + StartModelButton, +} from './model_select_option'; + +const DEFAULT_PROPS: ModelSelectOptionProps = { + modelId: 'model_1', + type: 'ner', + label: 'Model 1', + title: 'Model 1', + description: 'Model 1 description', + license: 'elastic', + deploymentState: MlModelDeploymentState.NotDeployed, + startTime: 0, + targetAllocationCount: 0, + nodeAllocationCount: 0, + threadsPerAllocation: 0, + isPlaceholder: false, + hasStats: false, +}; + +describe('ModelSelectOption', () => { + beforeEach(() => { + jest.clearAllMocks(); + setMockValues({}); + }); + it('renders with license badge if present', () => { + const wrapper = shallow(); + expect(wrapper.find(EuiBadge)).toHaveLength(1); + }); + it('renders without license badge if not present', () => { + const props = { + ...DEFAULT_PROPS, + license: undefined, + }; + + const wrapper = shallow(); + expect(wrapper.find(EuiBadge)).toHaveLength(0); + }); + it('renders with description if present', () => { + const wrapper = shallow(); + expect(wrapper.find(EuiText)).toHaveLength(1); + }); + it('renders without description if not present', () => { + const props = { + ...DEFAULT_PROPS, + description: undefined, + }; + + const wrapper = shallow(); + expect(wrapper.find(EuiText)).toHaveLength(0); + }); + it('renders deploy button for a model placeholder', () => { + const props = { + ...DEFAULT_PROPS, + isPlaceholder: true, + }; + + const wrapper = shallow(); + expect(wrapper.find(DeployModelButton)).toHaveLength(1); + }); + it('renders start button for a downloaded model', () => { + const props = { + ...DEFAULT_PROPS, + deploymentState: MlModelDeploymentState.Downloaded, + }; + + const wrapper = shallow(); + expect(wrapper.find(StartModelButton)).toHaveLength(1); + }); + it('renders status badge if there is no action button', () => { + const wrapper = shallow(); + expect(wrapper.find(TrainedModelHealth)).toHaveLength(1); + }); + + describe('getContextMenuPanel', () => { + it('gets model details link if URL is present', () => { + const panels = getContextMenuPanel('https://model.ai'); + expect(panels[0].items).toHaveLength(2); + }); + }); +}); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.tsx index a9efa40644540..3133dc6feb3bd 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/model_select_option.tsx @@ -5,55 +5,250 @@ * 2.0. */ -import React from 'react'; +import React, { useState } from 'react'; -import { EuiFlexGroup, EuiFlexItem, EuiTextColor, EuiTitle } from '@elastic/eui'; +import { useActions, useValues } from 'kea'; import { - getMlModelTypesForModelConfig, - parseModelStateFromStats, - parseModelStateReasonFromStats, -} from '../../../../../../../common/ml_inference_pipeline'; -import { TrainedModel } from '../../../../api/ml_models/ml_trained_models_logic'; -import { getMLType, getModelDisplayTitle } from '../../../shared/ml_inference/utils'; + EuiBadge, + EuiButton, + EuiButtonEmpty, + EuiButtonIcon, + EuiContextMenu, + EuiContextMenuPanelDescriptor, + EuiFlexGroup, + EuiFlexItem, + EuiPopover, + EuiRadio, + EuiText, + EuiTextColor, + EuiTitle, + useIsWithinMaxBreakpoint, +} from '@elastic/eui'; +import { i18n } from '@kbn/i18n'; + +import { MlModel, MlModelDeploymentState } from '../../../../../../../common/types/ml'; +import { KibanaLogic } from '../../../../../shared/kibana'; import { TrainedModelHealth } from '../ml_model_health'; -import { MLModelTypeBadge } from '../ml_model_type_badge'; - -export interface MlModelSelectOptionProps { - model: TrainedModel; -} -export const MlModelSelectOption: React.FC = ({ model }) => { - const type = getMLType(getMlModelTypesForModelConfig(model)); - const title = getModelDisplayTitle(type); + +import { ModelSelectLogic } from './model_select_logic'; +import { TRAINED_MODELS_PATH } from './utils'; + +export const getContextMenuPanel = ( + modelDetailsPageUrl?: string +): EuiContextMenuPanelDescriptor[] => { + return [ + { + id: 0, + items: [ + { + name: i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.actionMenu.tuneModelPerformance.label', + { + defaultMessage: 'Tune model performance', + } + ), + icon: 'controlsHorizontal', + onClick: () => + KibanaLogic.values.navigateToUrl(TRAINED_MODELS_PATH, { + shouldNotCreateHref: true, + }), + }, + ...(modelDetailsPageUrl + ? [ + { + name: i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.actionMenu.modelDetails.label', + { + defaultMessage: 'Model details', + } + ), + icon: 'popout', + href: modelDetailsPageUrl, + target: '_blank', + }, + ] + : []), + ], + }, + ]; +}; + +export type ModelSelectOptionProps = MlModel & { + label: string; + checked?: 'on'; +}; + +export const DeployModelButton: React.FC<{ onClick: () => void; disabled: boolean }> = ({ + onClick, + disabled, +}) => { return ( - - - -

{title ?? model.model_id}

-
+ + {i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.deployButton.label', + { + defaultMessage: 'Deploy', + } + )} + + ); +}; + +export const StartModelButton: React.FC<{ onClick: () => void; disabled: boolean }> = ({ + onClick, + disabled, +}) => { + return ( + + {i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.startButton.label', + { + defaultMessage: 'Start', + } + )} + + ); +}; + +export const ModelMenuPopover: React.FC<{ + onClick: () => void; + closePopover: () => void; + isOpen: boolean; + modelDetailsPageUrl?: string; +}> = ({ onClick, closePopover, isOpen, modelDetailsPageUrl }) => { + return ( + + } + isOpen={isOpen} + closePopover={closePopover} + anchorPosition="leftCenter" + panelPaddingSize="none" + > + + + ); +}; + +export const ModelSelectOption: React.FC = ({ + modelId, + title, + description, + license, + deploymentState, + deploymentStateReason, + modelDetailsPageUrl, + isPlaceholder, + checked, +}) => { + const [isPopoverOpen, setIsPopoverOpen] = useState(false); + const onMenuButtonClick = () => setIsPopoverOpen((isOpen) => !isOpen); + const closePopover = () => setIsPopoverOpen(false); + + const { createModel, startModel } = useActions(ModelSelectLogic); + const { areActionButtonsDisabled } = useValues(ModelSelectLogic); + + return ( + + {/* Selection radio button */} + + null} + // @ts-ignore + inert + /> - - - {title && ( + {/* Title, model ID, description, license */} + + + + +

{title}

+
+
+ + {modelId} + + {(license || description) && ( - {model.model_id} + + {license && ( + + {/* Wrap in a div to prevent the badge from growing to a whole row on mobile */} +
+ + {i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.modelSelectOption.licenseBadge.label', + { + defaultMessage: 'License: {license}', + values: { + license, + }, + } + )} + +
+
+ )} + {description && ( + + +
+ {description} +
+
+
+ )} +
)} - +
+
+ {/* Status indicator OR action button */} + + {/* Wrap in a div to prevent the badge/button from growing to a whole row on mobile */} +
+ {isPlaceholder ? ( + createModel(modelId)} + disabled={areActionButtonsDisabled} + /> + ) : deploymentState === MlModelDeploymentState.Downloaded ? ( + startModel(modelId)} + disabled={areActionButtonsDisabled} + /> + ) : ( - - - - - - - - - + )} +
+
+ {/* Actions menu */} + +
); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_model_health.test.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_model_health.test.tsx index 47136ff90f799..65bfbc0951d30 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_model_health.test.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_model_health.test.tsx @@ -13,6 +13,7 @@ import { shallow } from 'enzyme'; import { EuiHealth } from '@elastic/eui'; +import { MlModelDeploymentState } from '../../../../../../common/types/ml'; import { InferencePipeline, TrainedModelState } from '../../../../../../common/types/pipelines'; import { TrainedModelHealth } from './ml_model_health'; @@ -30,6 +31,18 @@ describe('TrainedModelHealth', () => { pipelineReferences: [], types: ['pytorch'], }; + it('renders model downloading', () => { + const wrapper = shallow(); + const health = wrapper.find(EuiHealth); + expect(health.prop('children')).toEqual('Downloading'); + expect(health.prop('color')).toEqual('warning'); + }); + it('renders model downloaded', () => { + const wrapper = shallow(); + const health = wrapper.find(EuiHealth); + expect(health.prop('children')).toEqual('Downloaded'); + expect(health.prop('color')).toEqual('subdued'); + }); it('renders model started', () => { const pipeline: InferencePipeline = { ...commonModelData, diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_model_health.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_model_health.tsx index 45fd54b6bf4fd..133582520deb8 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_model_health.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_model_health.tsx @@ -12,8 +12,33 @@ import { EuiHealth, EuiToolTip } from '@elastic/eui'; import { i18n } from '@kbn/i18n'; import { FormattedMessage } from '@kbn/i18n-react'; +import { MlModelDeploymentState } from '../../../../../../common/types/ml'; import { TrainedModelState } from '../../../../../../common/types/pipelines'; +const modelDownloadingText = i18n.translate( + 'xpack.enterpriseSearch.inferencePipelineCard.modelState.downloading', + { + defaultMessage: 'Downloading', + } +); +const modelDownloadingTooltip = i18n.translate( + 'xpack.enterpriseSearch.inferencePipelineCard.modelState.downloading.tooltip', + { + defaultMessage: 'This trained model is downloading', + } +); +const modelDownloadedText = i18n.translate( + 'xpack.enterpriseSearch.inferencePipelineCard.modelState.downloaded', + { + defaultMessage: 'Downloaded', + } +); +const modelDownloadedTooltip = i18n.translate( + 'xpack.enterpriseSearch.inferencePipelineCard.modelState.downloaded.tooltip', + { + defaultMessage: 'This trained model is downloaded and can be started', + } +); const modelStartedText = i18n.translate( 'xpack.enterpriseSearch.inferencePipelineCard.modelState.started', { @@ -73,7 +98,7 @@ const modelNotDeployedTooltip = i18n.translate( ); export interface TrainedModelHealthProps { - modelState: TrainedModelState; + modelState: TrainedModelState | MlModelDeploymentState; modelStateReason?: string; } @@ -87,27 +112,52 @@ export const TrainedModelHealth: React.FC = ({ tooltipText: React.ReactNode; }; switch (modelState) { - case TrainedModelState.Started: + case TrainedModelState.NotDeployed: + case MlModelDeploymentState.NotDeployed: modelHealth = { - healthColor: 'success', - healthText: modelStartedText, - tooltipText: modelStartedTooltip, + healthColor: 'danger', + healthText: modelNotDeployedText, + tooltipText: modelNotDeployedTooltip, }; break; - case TrainedModelState.Stopping: + case MlModelDeploymentState.Downloading: modelHealth = { healthColor: 'warning', - healthText: modelStoppingText, - tooltipText: modelStoppingTooltip, + healthText: modelDownloadingText, + tooltipText: modelDownloadingTooltip, + }; + break; + case MlModelDeploymentState.Downloaded: + modelHealth = { + healthColor: 'subdued', + healthText: modelDownloadedText, + tooltipText: modelDownloadedTooltip, }; break; case TrainedModelState.Starting: + case MlModelDeploymentState.Starting: modelHealth = { healthColor: 'warning', healthText: modelStartingText, tooltipText: modelStartingTooltip, }; break; + case TrainedModelState.Started: + case MlModelDeploymentState.Started: + case MlModelDeploymentState.FullyAllocated: + modelHealth = { + healthColor: 'success', + healthText: modelStartedText, + tooltipText: modelStartedTooltip, + }; + break; + case TrainedModelState.Stopping: + modelHealth = { + healthColor: 'warning', + healthText: modelStoppingText, + tooltipText: modelStoppingTooltip, + }; + break; case TrainedModelState.Failed: modelHealth = { healthColor: 'danger', @@ -133,7 +183,7 @@ export const TrainedModelHealth: React.FC = ({ ), }; break; - case TrainedModelState.NotDeployed: + default: modelHealth = { healthColor: 'danger', healthText: modelNotDeployedText, diff --git a/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.ts b/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.ts index 4f65dbf9ced64..becd34a6c3c95 100644 --- a/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.ts +++ b/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.ts @@ -43,7 +43,7 @@ export const startMlModelDeployment = async ( // we're downloaded already, but not deployed yet - let's deploy it const startRequest: MlStartTrainedModelDeploymentRequest = { model_id: modelName, - wait_for: 'started', + wait_for: 'starting', }; await trainedModelsProvider.startTrainedModelDeployment(startRequest); diff --git a/x-pack/plugins/enterprise_search/server/lib/ml/utils.ts b/x-pack/plugins/enterprise_search/server/lib/ml/utils.ts index 29aded727280d..cc81b78c3ff09 100644 --- a/x-pack/plugins/enterprise_search/server/lib/ml/utils.ts +++ b/x-pack/plugins/enterprise_search/server/lib/ml/utils.ts @@ -64,12 +64,11 @@ export const ELSER_MODEL_PLACEHOLDER: MlModel = { ...BASE_MODEL, modelId: ELSER_MODEL_ID, type: SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION, - title: 'Elastic Learned Sparse EncodeR (ELSER)', + title: 'ELSER (Elastic Learned Sparse EncodeR)', description: i18n.translate('xpack.enterpriseSearch.modelCard.elserPlaceholder.description', { defaultMessage: - 'ELSER is designed to efficiently use context in natural language queries with better results than BM25 alone.', + "ELSER is Elastic's NLP model for English semantic search, utilizing sparse vectors. It prioritizes intent and contextual meaning over literal term matching, optimized specifically for English documents and queries on the Elastic platform.", }), - license: 'Elastic', isPlaceholder: true, }; @@ -77,9 +76,10 @@ export const E5_MODEL_PLACEHOLDER: MlModel = { ...BASE_MODEL, modelId: E5_MODEL_ID, type: SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING, - title: 'E5 Multilingual Embedding', + title: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)', description: i18n.translate('xpack.enterpriseSearch.modelCard.e5Placeholder.description', { - defaultMessage: 'Multilingual dense vector embedding generator.', + defaultMessage: + 'E5 is an NLP model that enables you to perform multi-lingual semantic search by using dense vector representations. This model performs best for non-English language documents and queries.', }), license: 'MIT', modelDetailsPageUrl: 'https://huggingface.co/intfloat/multilingual-e5-small',