Skip to content

Commit

Permalink
[ML] Support E5 model download from the list view, mark not supported…
Browse files Browse the repository at this point in the history
…/not optimized models (#189372)

## Summary

- Adds the "Show all" switch at the top of the page (off by default)
that responsible for hiding unsupported/not optimised model versions
- Adds a download action for E5 models 
- Removes the toast notification on successful start of a model download
- Adds a warning icon with an appropriate tooltip text if model version
is not supported by the cluster architecture or not optimised for it

<img width="1401" alt="image"
src="https://github.com/user-attachments/assets/30748315-3018-46e2-91dd-fd5128b0695e">



### Checklist


- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
  • Loading branch information
darnautov authored Aug 7, 2024
1 parent 199d0f6 commit 514db54
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ export const BUILT_IN_MODEL_TAG = 'prepackaged';

export const ELASTIC_MODEL_TAG = 'elastic';

export const ELASTIC_MODEL_DEFINITIONS: Record<string, ModelDefinition> = Object.freeze({
export const ELASTIC_MODEL_DEFINITIONS: Record<
string,
Omit<ModelDefinition, 'supported'>
> = Object.freeze({
[ELSER_ID_V1]: {
modelName: 'elser',
hidden: true,
Expand Down Expand Up @@ -156,6 +159,8 @@ export interface ModelDefinition {
default?: boolean;
/** Indicates if model version is recommended for deployment based on the cluster configuration */
recommended?: boolean;
/** Indicates if model version is supported by the cluster */
supported: boolean;
hidden?: boolean;
/** Software license of a model, e.g. MIT */
license?: string;
Expand Down
1 change: 1 addition & 0 deletions x-pack/plugins/ml/common/types/common.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ export interface ListingPageUrlState {
sortField: string;
sortDirection: string;
queryText?: string;
showAll?: boolean;
}

export type AppPageState<T> = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ import {
DEPLOYMENT_STATE,
TRAINED_MODEL_TYPE,
} from '@kbn/ml-trained-models-utils';
import {
ELASTIC_MODEL_TAG,
MODEL_STATE,
} from '@kbn/ml-trained-models-utils/src/constants/trained_models';
import { MODEL_STATE } from '@kbn/ml-trained-models-utils/src/constants/trained_models';
import {
getAnalysisType,
type DataFrameAnalysisConfigType,
Expand Down Expand Up @@ -409,10 +406,7 @@ export function useModelActions({
icon: 'download',
type: 'icon',
isPrimary: true,
available: (item) =>
canCreateTrainedModels &&
item.tags.includes(ELASTIC_MODEL_TAG) &&
item.state === MODEL_STATE.NOT_DOWNLOADED,
available: (item) => canCreateTrainedModels && item.state === MODEL_STATE.NOT_DOWNLOADED,
enabled: (item) => !isLoading,
onClick: async (item) => {
onModelDownloadRequest(item.model_id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@ import {
EuiFlexGroup,
EuiFlexItem,
EuiHealth,
EuiIcon,
EuiInMemoryTable,
EuiLink,
type EuiSearchBarProps,
EuiProgress,
EuiSpacer,
EuiSwitch,
EuiTitle,
EuiToolTip,
EuiProgress,
type EuiSearchBarProps,
} from '@elastic/eui';
import { groupBy, isEmpty } from 'lodash';
import { i18n } from '@kbn/i18n';
Expand Down Expand Up @@ -94,6 +96,7 @@ export type ModelItem = TrainedModelConfigResponse & {
*/
stateDescription?: string;
recommended?: boolean;
supported: boolean;
/**
* Model name, e.g. elser
*/
Expand Down Expand Up @@ -129,6 +132,7 @@ export const getDefaultModelsListState = (): ListingPageUrlState => ({
pageSize: 10,
sortField: modelIdColumnName,
sortDirection: 'asc',
showAll: false,
});

interface Props {
Expand Down Expand Up @@ -286,9 +290,13 @@ export const ModelsList: FC<Props> = ({
);
const forDownload = await trainedModelsApiService.getTrainedModelDownloads();
const notDownloaded: ModelItem[] = forDownload
.filter(({ model_id: modelId, hidden, recommended }) => {
if (recommended && idMap.has(modelId)) {
idMap.get(modelId)!.recommended = true;
.filter(({ model_id: modelId, hidden, recommended, supported }) => {
if (idMap.has(modelId)) {
const model = idMap.get(modelId)!;
if (recommended) {
model.recommended = true;
}
model.supported = supported;
}
return !idMap.has(modelId) && !hidden;
})
Expand All @@ -306,6 +314,7 @@ export const ModelsList: FC<Props> = ({
arch: modelDefinition.arch,
softwareLicense: modelDefinition.license,
licenseUrl: modelDefinition.licenseUrl,
supported: modelDefinition.supported,
} as ModelItem;
});
resultItems = [...resultItems, ...notDownloaded];
Expand Down Expand Up @@ -530,12 +539,6 @@ export const ModelsList: FC<Props> = ({
try {
setIsLoading(true);
await trainedModelsApiService.installElasticTrainedModelConfig(modelId);
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.downloadSuccess', {
defaultMessage: '"{modelId}" model download has been started successfully.',
values: { modelId },
})
);
// Need to fetch model state updates
await fetchModelsData();
} catch (e) {
Expand All @@ -549,7 +552,7 @@ export const ModelsList: FC<Props> = ({
setIsLoading(true);
}
},
[displayErrorToast, displaySuccessToast, fetchModelsData, trainedModelsApiService]
[displayErrorToast, fetchModelsData, trainedModelsApiService]
);

/**
Expand Down Expand Up @@ -633,26 +636,28 @@ export const ModelsList: FC<Props> = ({
}),
truncateText: false,
'data-test-subj': 'mlModelsTableColumnDescription',
render: ({ description, recommended }: ModelItem) => {
render: ({ description, recommended, tags, supported }: ModelItem) => {
if (!description) return null;
const descriptionText = description.replace('(Tech Preview)', '');
return recommended ? (
<EuiToolTip
content={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.recommendedDownloadContent"
defaultMessage="Recommended model version for your cluster's hardware configuration"
/>
}
>

const tooltipContent =
supported === false ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.notSupportedDownloadContent"
defaultMessage="Model version is not supported by your cluster's hardware configuration"
/>
) : recommended === false ? (
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.notRecommendedDownloadContent"
defaultMessage="Model version is not optimized for your cluster's hardware configuration"
/>
) : null;

return tooltipContent ? (
<EuiToolTip content={tooltipContent}>
<>
{descriptionText}&nbsp;
<b>
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.recommendedDownloadLabel"
defaultMessage="(Recommended)"
/>
</b>
<EuiIcon type={'warning'} color="warning" />
</>
</EuiToolTip>
) : (
Expand Down Expand Up @@ -861,15 +866,39 @@ export const ModelsList: FC<Props> = ({
const isElserCalloutVisible =
!isElserCalloutDismissed && items.findIndex((i) => i.model_id === ELSER_ID_V1) >= 0;

const tableItems = useMemo(() => {
if (pageState.showAll) {
return items;
} else {
return items.filter((item) => item.supported !== false);
}
}, [items, pageState.showAll]);

if (!isInitialized) return null;

return (
<>
<SavedObjectsWarning onCloseFlyout={fetchModelsData} forceRefresh={isLoading} />
<EuiFlexGroup justifyContent="spaceBetween">
{modelsStats ? (
<EuiFlexItem grow={false}>
<StatsBar stats={modelsStats} dataTestSub={'mlInferenceModelsStatsBar'} />
<EuiFlexItem>
<EuiFlexGroup alignItems="center">
<EuiFlexItem grow={false}>
<StatsBar stats={modelsStats} dataTestSub={'mlInferenceModelsStatsBar'} />
</EuiFlexItem>
<EuiFlexItem grow={false}>
<EuiSwitch
label={
<FormattedMessage
id="xpack.ml.trainedModels.modelsList.showAllLabel"
defaultMessage="Show all"
/>
}
checked={!!pageState.showAll}
onChange={(e) => updatePageState({ showAll: e.target.checked })}
/>
</EuiFlexItem>
</EuiFlexGroup>
</EuiFlexItem>
) : null}
<EuiFlexItem grow={false}>
Expand All @@ -894,7 +923,7 @@ export const ModelsList: FC<Props> = ({
allowNeutralSort={false}
columns={columns}
itemIdToExpandedRowMap={itemIdToExpandedRowMap}
items={items}
items={tableItems}
itemId={ModelsTableToConfigMapping.id}
loading={isLoading}
search={search}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ describe('modelsProvider', () => {
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
hidden: true,
supported: false,
model_id: '.elser_model_1',
version: 1,
modelName: 'elser',
Expand All @@ -66,6 +67,7 @@ describe('modelsProvider', () => {
{
config: { input: { field_names: ['text_field'] } },
default: true,
supported: true,
description: 'Elastic Learned Sparse EncodeR v2',
model_id: '.elser_model_2',
version: 2,
Expand All @@ -79,6 +81,7 @@ describe('modelsProvider', () => {
model_id: '.elser_model_2_linux-x86_64',
os: 'Linux',
recommended: true,
supported: true,
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
Expand All @@ -88,6 +91,7 @@ describe('modelsProvider', () => {
description: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)',
model_id: '.multilingual-e5-small',
default: true,
supported: true,
version: 1,
modelName: 'e5',
license: 'MIT',
Expand All @@ -102,6 +106,7 @@ describe('modelsProvider', () => {
model_id: '.multilingual-e5-small_linux-x86_64',
os: 'Linux',
recommended: true,
supported: true,
version: 1,
modelName: 'e5',
license: 'MIT',
Expand Down Expand Up @@ -140,6 +145,7 @@ describe('modelsProvider', () => {
config: { input: { field_names: ['text_field'] } },
description: 'Elastic Learned Sparse EncodeR v1 (Tech Preview)',
hidden: true,
supported: false,
model_id: '.elser_model_1',
version: 1,
modelName: 'elser',
Expand All @@ -148,6 +154,7 @@ describe('modelsProvider', () => {
{
config: { input: { field_names: ['text_field'] } },
recommended: true,
supported: true,
description: 'Elastic Learned Sparse EncodeR v2',
model_id: '.elser_model_2',
version: 2,
Expand All @@ -163,12 +170,14 @@ describe('modelsProvider', () => {
version: 2,
modelName: 'elser',
type: ['elastic', 'pytorch', 'text_expansion'],
supported: false,
},
{
config: { input: { field_names: ['text_field'] } },
description: 'E5 (EmbEddings from bidirEctional Encoder rEpresentations)',
model_id: '.multilingual-e5-small',
recommended: true,
supported: true,
version: 1,
modelName: 'e5',
type: ['pytorch', 'text_embedding'],
Expand All @@ -182,6 +191,7 @@ describe('modelsProvider', () => {
'E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimized for linux-x86_64',
model_id: '.multilingual-e5-small_linux-x86_64',
os: 'Linux',
supported: false,
version: 1,
modelName: 'e5',
type: ['pytorch', 'text_embedding'],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ export class ModelsProvider {
const modelDefinitionResponse = {
...def,
...(recommended ? { recommended } : {}),
supported: !!def.default || recommended,
model_id: modelId,
};

Expand Down
2 changes: 0 additions & 2 deletions x-pack/plugins/translations/translations/fr-FR.json
Original file line number Diff line number Diff line change
Expand Up @@ -28304,7 +28304,6 @@
"xpack.ml.trainedModels.modelsList.disableSelectableMessage": "Le modèle a des pipelines associés",
"xpack.ml.trainedModels.modelsList.downloadFailed": "Échec du téléchargement de \"{modelId}\"",
"xpack.ml.trainedModels.modelsList.downloadStatusCheckErrorMessage": "Échec de la vérification du statut du téléchargement",
"xpack.ml.trainedModels.modelsList.downloadSuccess": "Le téléchargement du modèle \"{modelId}\" a bien été démarré.",
"xpack.ml.trainedModels.modelsList.e5Title": "E5 (EmbEddings from bidirEctional Encoder rEpresentations)",
"xpack.ml.trainedModels.modelsList.e5v1Description": "E5 (EmbEddings from bidirEctional Encoder rEpresentations)",
"xpack.ml.trainedModels.modelsList.e5v1x86Description": "E5 (EmbEddings from bidirEctional Encoder rEpresentations), optimisé for linux-x86_64",
Expand Down Expand Up @@ -28362,7 +28361,6 @@
"xpack.ml.trainedModels.modelsList.pipelines.processorStats.timePerDocHeader": "Temps par document",
"xpack.ml.trainedModels.modelsList.pipelines.processorStats.typeHeader": "Type de processeur",
"xpack.ml.trainedModels.modelsList.recommendedDownloadContent": "Version du modèle recommandée pour la configuration matérielle de votre cluster",
"xpack.ml.trainedModels.modelsList.recommendedDownloadLabel": "(Recommandée)",
"xpack.ml.trainedModels.modelsList.selectableMessage": "Sélectionner un modèle",
"xpack.ml.trainedModels.modelsList.selectedModelsMessage": "{modelsCount, plural, one{# modèle sélectionné} other {# modèles sélectionnés}}",
"xpack.ml.trainedModels.modelsList.startDeployment.cancelButton": "Annuler",
Expand Down
2 changes: 0 additions & 2 deletions x-pack/plugins/translations/translations/ja-JP.json
Original file line number Diff line number Diff line change
Expand Up @@ -28209,7 +28209,6 @@
"xpack.ml.trainedModels.modelsList.disableSelectableMessage": "モデルにはパイプラインが関連付けられています",
"xpack.ml.trainedModels.modelsList.downloadFailed": "\"{modelId}\"をダウンロードできませんでした",
"xpack.ml.trainedModels.modelsList.downloadStatusCheckErrorMessage": "ダウンロードステータスを確認できませんでした",
"xpack.ml.trainedModels.modelsList.downloadSuccess": "\"{modelId}\"モデルのダウンロードが正常に開始しました。",
"xpack.ml.trainedModels.modelsList.e5Title": "E5(bidirEctional Encoder rEpresentationsからのEmbEddings)",
"xpack.ml.trainedModels.modelsList.e5v1Description": "E5(bidirEctional Encoder rEpresentationsからのEmbEddings)",
"xpack.ml.trainedModels.modelsList.e5v1x86Description": "E5(bidirEctional Encoder rEpresentationsからのEmbEddings)、inux-x86_64向けに最適化",
Expand Down Expand Up @@ -28267,7 +28266,6 @@
"xpack.ml.trainedModels.modelsList.pipelines.processorStats.timePerDocHeader": "ドキュメントごとの時間",
"xpack.ml.trainedModels.modelsList.pipelines.processorStats.typeHeader": "プロセッサータイプ",
"xpack.ml.trainedModels.modelsList.recommendedDownloadContent": "クラスターのハードウェア構成に応じた推奨モデルバージョン",
"xpack.ml.trainedModels.modelsList.recommendedDownloadLabel": "(推奨)",
"xpack.ml.trainedModels.modelsList.selectableMessage": "モデルを選択",
"xpack.ml.trainedModels.modelsList.selectedModelsMessage": "{modelsCount, plural, other {#個のモデル}}が選択されました",
"xpack.ml.trainedModels.modelsList.startDeployment.cancelButton": "キャンセル",
Expand Down
2 changes: 0 additions & 2 deletions x-pack/plugins/translations/translations/zh-CN.json
Original file line number Diff line number Diff line change
Expand Up @@ -28344,7 +28344,6 @@
"xpack.ml.trainedModels.modelsList.disableSelectableMessage": "模型有关联的管道",
"xpack.ml.trainedModels.modelsList.downloadFailed": "无法下载“{modelId}”",
"xpack.ml.trainedModels.modelsList.downloadStatusCheckErrorMessage": "无法检查下载状态",
"xpack.ml.trainedModels.modelsList.downloadSuccess": "已成功启动“{modelId}”模型下载。",
"xpack.ml.trainedModels.modelsList.e5Title": "E5 (EmbEddings from bidirEctional Encoder rEpresentations)",
"xpack.ml.trainedModels.modelsList.e5v1Description": "E5 (EmbEddings from bidirEctional Encoder rEpresentations)",
"xpack.ml.trainedModels.modelsList.e5v1x86Description": "针对 linux-x86_64 进行了优化的 E5 (EmbEddings from bidirEctional Encoder rEpresentations)",
Expand Down Expand Up @@ -28402,7 +28401,6 @@
"xpack.ml.trainedModels.modelsList.pipelines.processorStats.timePerDocHeader": "每个文档的时间",
"xpack.ml.trainedModels.modelsList.pipelines.processorStats.typeHeader": "处理器类型",
"xpack.ml.trainedModels.modelsList.recommendedDownloadContent": "为您集群的硬件配置推荐的模型版本",
"xpack.ml.trainedModels.modelsList.recommendedDownloadLabel": "(推荐)",
"xpack.ml.trainedModels.modelsList.selectableMessage": "选择模型",
"xpack.ml.trainedModels.modelsList.selectedModelsMessage": "{modelsCount, plural, other {# 个模型}}已选择",
"xpack.ml.trainedModels.modelsList.startDeployment.cancelButton": "取消",
Expand Down
Loading

0 comments on commit 514db54

Please sign in to comment.