From a06a9d8d6f1a32be9bb4ca3da4cb97a9c3771cda Mon Sep 17 00:00:00 2001 From: Robert Oskamp Date: Fri, 8 Jul 2022 15:31:42 +0200 Subject: [PATCH 1/2] [ML] Functional tests - support tiny trained models --- .../plugins/ml/common/types/trained_models.ts | 1 + .../model_management/model_list.ts | 35 +++++- x-pack/test/functional/services/ml/api.ts | 102 +++++++++++++----- x-pack/test/functional/services/ml/index.ts | 2 +- .../pt_tiny_fill_mask/config.json | 7 ++ .../pt_tiny_fill_mask/traced_pytorch_model.pt | Bin 0 -> 1739 bytes .../pt_tiny_fill_mask/vocabulary.json | 16 +++ .../pt_tiny_ner/config.json | 7 ++ .../pt_tiny_ner/traced_pytorch_model.pt | Bin 0 -> 1675 bytes .../pt_tiny_ner/vocabulary.json | 21 ++++ .../pt_tiny_pass_through/config.json | 13 +++ .../traced_pytorch_model.pt | Bin 0 -> 1630 bytes .../pt_tiny_pass_through/vocabulary.json | 5 + .../pt_tiny_text_classification/config.json | 9 ++ .../traced_pytorch_model.pt | Bin 0 -> 1606 bytes .../vocabulary.json | 16 +++ .../pt_tiny_text_embedding/config.json | 7 ++ .../traced_pytorch_model.pt | Bin 0 -> 1517 bytes .../pt_tiny_text_embedding/vocabulary.json | 16 +++ .../pt_tiny_zero_shot/config.json | 9 ++ .../pt_tiny_zero_shot/traced_pytorch_model.pt | Bin 0 -> 1492 bytes .../pt_tiny_zero_shot/vocabulary.json | 24 +++++ .../functional/services/ml/trained_models.ts | 17 +-- 23 files changed, 263 insertions(+), 44 deletions(-) create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/config.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/traced_pytorch_model.pt create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/vocabulary.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_ner/config.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_ner/traced_pytorch_model.pt create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_ner/vocabulary.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_pass_through/config.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_pass_through/traced_pytorch_model.pt create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_pass_through/vocabulary.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_classification/config.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_classification/traced_pytorch_model.pt create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_classification/vocabulary.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/config.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/traced_pytorch_model.pt create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/vocabulary.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_zero_shot/config.json create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_zero_shot/traced_pytorch_model.pt create mode 100644 x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_zero_shot/vocabulary.json diff --git a/x-pack/plugins/ml/common/types/trained_models.ts b/x-pack/plugins/ml/common/types/trained_models.ts index 8798ed5ccb5a2..2dea2eecd5e5a 100644 --- a/x-pack/plugins/ml/common/types/trained_models.ts +++ b/x-pack/plugins/ml/common/types/trained_models.ts @@ -66,6 +66,7 @@ export type PutTrainedModelConfig = { model_aliases?: string[]; } & Record; tags?: string[]; + model_type?: TrainedModelType; inference_config?: Record; input: { field_names: string[] }; } & XOR< diff --git a/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts b/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts index ca360130b89f9..688ab3d6ced80 100644 --- a/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts +++ b/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts @@ -6,14 +6,33 @@ */ import { FtrProviderContext } from '../../../../ftr_provider_context'; +import { TrainedModelName } from '../../../../services/ml/api'; export default function ({ getService }: FtrProviderContext) { const ml = getService('ml'); + const tinyTrainedModels = [ + 'fill_mask', + 'ner', + 'pass_through', + 'text_classification', + 'text_embedding', + 'zero_shot', + ].map((type) => ({ + id: `tiny_${type}`, + name: `pt_tiny_${type}` as TrainedModelName, + desription: `Tiny/Dummy PyTorch model (${type})`, + modelTypes: ['pytorch', type], + })); + describe('trained models', function () { before(async () => { - await ml.trainedModels.createTestTrainedModels('classification', 15, true); - await ml.trainedModels.createTestTrainedModels('regression', 15); + for (const model of tinyTrainedModels) { + await ml.api.importTrainedModel(model.id, model.name); + } + + await ml.api.createTestTrainedModels('classification', 15, true); + await ml.api.createTestTrainedModels('regression', 15); }); after(async () => { @@ -56,7 +75,7 @@ export default function ({ getService }: FtrProviderContext) { 'should display the stats bar with the total number of models' ); // +1 because of the built-in model - await ml.trainedModels.assertStats(31); + await ml.trainedModels.assertStats(37); await ml.testExecution.logTestStep('should display the table'); await ml.trainedModels.assertTableExists(); @@ -81,6 +100,16 @@ export default function ({ getService }: FtrProviderContext) { await ml.trainedModelsTable.assertPipelinesTabContent(false); }); + for (const model of tinyTrainedModels) { + it(`renders expanded row content correctly for imported tiny model ${model.id} without pipelines`, async () => { + await ml.trainedModelsTable.ensureRowIsExpanded(model.id); + await ml.trainedModelsTable.assertDetailsTabContent(); + await ml.trainedModelsTable.assertInferenceConfigTabContent(); + await ml.trainedModelsTable.assertStatsTabContent(); + await ml.trainedModelsTable.assertPipelinesTabContent(false); + }); + } + it('displays the built-in model and no actions are enabled', async () => { await ml.testExecution.logTestStep('should display the model in the table'); await ml.trainedModelsTable.filterWithSearchString(builtInModelData.modelId, 1); diff --git a/x-pack/test/functional/services/ml/api.ts b/x-pack/test/functional/services/ml/api.ts index 55d5a978dae82..71dac67f4e251 100644 --- a/x-pack/test/functional/services/ml/api.ts +++ b/x-pack/test/functional/services/ml/api.ts @@ -9,7 +9,6 @@ import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey'; import expect from '@kbn/expect'; import { ProvidedType } from '@kbn/test'; import fs from 'fs'; -import path from 'path'; import { Calendar } from '@kbn/ml-plugin/server/models/calendar'; import { Annotation } from '@kbn/ml-plugin/common/types/annotations'; import { DataFrameAnalyticsConfig } from '@kbn/ml-plugin/public/application/data_frame_analytics/common'; @@ -29,6 +28,16 @@ import { FtrProviderContext } from '../../ftr_provider_context'; export type MlApi = ProvidedType; type ModelType = 'regression' | 'classification'; +export type TrainedModelName = + | 'pt_tiny_fill_mask' + | 'pt_tiny_ner' + | 'pt_tiny_pass_through' + | 'pt_tiny_text_classification' + | 'pt_tiny_text_embedding' + | 'pt_tiny_zero_shot'; +export interface TrainedModelVocabulary { + vocabulary: string[]; +} export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { const es = getService('es'); @@ -1135,6 +1144,38 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { return model; }, + async createTrainedModelVocabularyES(modelId: string, body: TrainedModelVocabulary) { + log.debug(`Creating vocabulary for trained model "${modelId}"`); + const { body: responseBody, status } = await esSupertest + .put(`/_ml/trained_models/${modelId}/vocabulary`) + .send(body); + this.assertResponseStatusCode(200, status, responseBody); + + log.debug('> Trained model vocabulary created'); + }, + + /** + * For the purpose of the functional tests where we only deal with very + * small models, we assume that the model definition can be uploaded as + * one part. + */ + async uploadTrainedModelDefinitionES(modelId: string, modelDefinitionPath: string) { + log.debug(`Uploading definition for trained model "${modelId}"`); + + const body = { + total_definition_length: fs.statSync(modelDefinitionPath).size, + definition: fs.readFileSync(modelDefinitionPath).toString('base64'), + total_parts: 1, + }; + + const { body: responseBody, status } = await esSupertest + .put(`/_ml/trained_models/${modelId}/definition/0`) + .send(body); + this.assertResponseStatusCode(200, status, responseBody); + + log.debug('> Trained model definition uploaded'); + }, + async deleteTrainedModelES(modelId: string) { log.debug(`Creating trained model with id "${modelId}"`); const { body: model, status } = await esSupertest.delete(`/_ml/trained_models/${modelId}`); @@ -1149,24 +1190,9 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { count: number = 10, withIngestPipelines = false ) { - const compressedDefinition = this.getCompressedModelDefinition(modelType); + const modelIds = new Array(count).fill(null).map((_v, i) => `dfa_${modelType}_model_n_${i}`); - const modelIds = new Array(count).fill(null).map((v, i) => `dfa_${modelType}_model_n_${i}`); - - const models = modelIds.map((id) => { - return { - model_id: id, - body: { - compressed_definition: compressedDefinition, - inference_config: { - [modelType]: {}, - }, - input: { - field_names: ['common_field'], - }, - } as PutTrainedModelConfig, - }; - }); + const models = modelIds.map((id) => this.createTestTrainedModelConfig(id, modelType)); for (const model of models) { await this.createTrainedModel(model.model_id, model.body); @@ -1178,7 +1204,7 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { return modelIds; }, - async createTestTrainedModelConfig(modelId: string, modelType: ModelType) { + createTestTrainedModelConfig(modelId: string, modelType: ModelType) { const compressedDefinition = this.getCompressedModelDefinition(modelType); return { @@ -1201,16 +1227,44 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { */ getCompressedModelDefinition(modelType: ModelType) { return fs.readFileSync( - path.resolve( - __dirname, - 'resources', - 'trained_model_definitions', - `minimum_valid_config_${modelType}.json.gz.b64` + require.resolve( + `./resources/trained_model_definitions/minimum_valid_config_${modelType}.json.gz.b64` ), 'utf-8' ); }, + getTrainedModelConfig(modelName: TrainedModelName) { + const configFileContent = fs.readFileSync( + require.resolve(`./resources/trained_model_definitions/${modelName}/config.json`), + 'utf-8' + ); + return JSON.parse(configFileContent) as PutTrainedModelConfig; + }, + + getTrainedModelVocabulary(modelName: TrainedModelName) { + const vocabularyFileContent = fs.readFileSync( + require.resolve(`./resources/trained_model_definitions/${modelName}/vocabulary.json`), + 'utf-8' + ); + return JSON.parse(vocabularyFileContent) as TrainedModelVocabulary; + }, + + getTrainedModelDefinitionPath(modelName: TrainedModelName) { + return require.resolve( + `./resources/trained_model_definitions/${modelName}/traced_pytorch_model.pt` + ); + }, + + async importTrainedModel(modelId: string, modelName: TrainedModelName) { + await this.createTrainedModel(modelId, this.getTrainedModelConfig(modelName)); + await this.createTrainedModelVocabularyES(modelId, this.getTrainedModelVocabulary(modelName)); + await this.uploadTrainedModelDefinitionES( + modelId, + this.getTrainedModelDefinitionPath(modelName) + ); + }, + async createModelAlias(modelId: string, modelAlias: string) { log.debug(`Creating alias for model "${modelId}"`); const { body, status } = await esSupertest.put( diff --git a/x-pack/test/functional/services/ml/index.ts b/x-pack/test/functional/services/ml/index.ts index ae7cb38e1c695..561bf4f6026b5 100644 --- a/x-pack/test/functional/services/ml/index.ts +++ b/x-pack/test/functional/services/ml/index.ts @@ -124,7 +124,7 @@ export function MachineLearningProvider(context: FtrProviderContext) { const testResources = MachineLearningTestResourcesProvider(context, api); const alerting = MachineLearningAlertingProvider(context, api, commonUI); const swimLane = SwimLaneProvider(context); - const trainedModels = TrainedModelsProvider(context, api, commonUI); + const trainedModels = TrainedModelsProvider(context, commonUI); const trainedModelsTable = TrainedModelsTableProvider(context, commonUI); const mlNodesPanel = MlNodesPanelProvider(context); diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/config.json b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/config.json new file mode 100644 index 0000000000000..cf00e1ade82ed --- /dev/null +++ b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/config.json @@ -0,0 +1,7 @@ +{ + "description": "Tiny/Dummy PyTorch model (fill_mask)", + "model_type": "pytorch", + "inference_config": { + "fill_mask": {} + } +} \ No newline at end of file diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/traced_pytorch_model.pt b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/traced_pytorch_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..6e7648e0ca77a621f9e456afd04337c2c01873ef GIT binary patch literal 1739 zcmWIWW@cev;NW1u0J03M3?)U0$*C#v1(hZFMadcQx%nxnIr=GyC5d_k**R`bj0{l? zOv&-_5ZQRHkj%VFx6GUz-^AiJb&b{b4<;(T( zNzNJV6F%PiE%&}~a;Q>)(Ys%HYizY5S1F&r7w%t70XMbOU|3x zrfyLYSKrMMdg~iezSrC^Wh|Bj|aej z>}SLmm<9|^L?6^efM=ov^_0Q{X}# z!6#Ziv(D;!ef%MzEb`;Q1h$ozxR+ICDEL>;eGtwOdCg4l4o|c91h<#5-J!)3I%B78 z_!K)OOylII*P;c|Jx#@`iQ+m_1$)2FzY*@O)|kXpUvT%K`HF+rqJG(ZHR(2)Y-1z) z#Z14uXGimd?j(~JW!9nDHIMh}?lW58dnWW;iLCyVd9x?^D(c-Z5f=JoX`tIW=c~5Z z^{^*uc9(55XKp*aBtWJ{u`6A{(U`kav?#XkRO(iDznH6P>z|wze4}yrW}21ooUJFV ze-$;lw)-*G-?C9+1PZTI^>vj;d)u2EEqi<=xY$T1WHLj0My40qBGe*Bm)CTI>4J1 N%wq*sQ6TjYwE*ycCZqrW literal 0 HcmV?d00001 diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/vocabulary.json b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/vocabulary.json new file mode 100644 index 0000000000000..9dc30191a0ca4 --- /dev/null +++ b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_fill_mask/vocabulary.json @@ -0,0 +1,16 @@ +{ + "vocabulary": [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + "Hello", + "world", + "car", + "bike", + "bee", + "bird", + "and" + ] +} \ No newline at end of file diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_ner/config.json b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_ner/config.json new file mode 100644 index 0000000000000..2d50493dbc938 --- /dev/null +++ b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_ner/config.json @@ -0,0 +1,7 @@ +{ + "description": "Tiny/Dummy PyTorch model (ner)", + "model_type": "pytorch", + "inference_config": { + "ner": {} + } +} \ No newline at end of file diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_ner/traced_pytorch_model.pt b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_ner/traced_pytorch_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..35e36941c2e7960f159a672e45563eda2dee7db6 GIT binary patch literal 1675 zcmWIWW@cev;NW1u0J03M3?)U0$*C#v1(hZFMadcQx%nxnIr=GyC5d_k**R`bj0{l? zOv&-_5ZQRHkj%VFztkeGLI%yoT8#*hp+FNd^D^_&3mH2j#DM(x%;Na8(wv<5q{QUx z^2DN)_>BDg>_R5L(xgIWy#Q}^j$7Y~*7*Tdf^Y!Pbs|95VLDHZK@Pj?l7Vj2hdW8H zpwi7riXjRoXq|92>#%`9%lpbM+glIVE-zNBZC*D0PETY?_Eiq;sD<)gzrIXO5q_cg zWQXy6gTs1mEezth@6$h>RX?4wLY9xA!>^)nv4@?)A=KrF=&7x^K2`9aR7TgeD#+hb>c>EW#CKC^Xo$~q?-_XLPI9k7{| zz4?*b7o8`c9nF7<22I`Sy>!Fl7i?3mKD+Sfi%`g`YpOXl*OyL?`uHc2*~V@13omU8 z7Vq5DF<=gb9U~{-MW0oD~*3Fpx80ezS3j?j2UevNbF$Ghe`|K6mE?1v399R%AfG{Z6VRwrRgAsPOl%*CGXXfX*Iq8w; z3S%x%NL-wk9tsQu1`rPLW&~02G6}gnkOfJg0K6p>#3)1of}9=HP>h)evJ8$WZb$~YL?3RSUO}aslP!{C;I1N`QQBvieb_*t_4}VL<;gtj3NJey ze9(7ymg-F>fv#FF)=7SImWX}79v+dK*`9N${Qcv-Cs#>4oWT6);EsgAJ(uR_njL0y zIJF?c(&t2=m&C)8xRbL@reKE=8lZ70b5*b_Ilp@o0;t` za^S~=$tJ}z4?66G!q@zmCHqUI-2BI3-hV$1OYf-uSI3v_@-5Y<^#-fiy4&gHORB_P z@6em&FS5Tl>c#A3Ro{=U7W^b}^yCkfl>1vYSZ_C2uKM7|m8MYcrzY>bmj%?A3-?`J zCBJd5(>*b534K|sN2xV|W=Xk;J~L}IZ`Uw_Lf@A~qV^Lo=sw?a|B5ExdnFQomxts*D?Fj26ok-kbKIv!p@ndf>kjy*oYZ zo%1u|&ogB{T=(tLm&#xJ_83bn_tg9C+Ouv_jn7K)wL-ss&)D9f+@F49{qO#@ z+86iNRr$V(SHCe$Rk*}*A-AKN@0}2T)9w6MzFeAjS^Kgq`NKldaq=JFH%(T zl6O{`{=nt$oypho-f{?i2zl!h-*D>gws(G^?@pikea+5ck(AY|=MIaG_s!cLdR}%< zKxjC7W7yD6B##|upT%4C4 z$^^6-gaf=8K@_}9LM|ypKoTebTL}d*0a2nLr*>Hs-7P>KvI!XF3%XIr;jWHi6cfJ?PFIk5rV4g#(HhIxk#cwE2FX;D8PTF*1`&WJlI=AXNr z*`dentkX;r#Y2L>5*H^IzT-cYWl`{>fzM4aUe)hK*LRCqng2|(J}Mcl|NK27y664s zh|SYvEH1t>>D=@BT9b}dtN{OZzPsxq8TdA)q|C0%wF~xGboiq$Q|wwDW7aeGC%#)1 zD0KZ{Y+<&a$7zrMvvh524vUsVduWRDZunasAe?jA{@ugOjB&83mP~%O&n!1B@kP<3N%4gXudREuHg+QS(;XsT3!W~?{PjqvW$Lq8b0!vRD%`i4 zvxhZsXK&O08^^r8qU6-et-k%8xAwt1wtpqM;nC8D53adQm~R-|ZDxJ)4PeZ_m8T{@dDZkIu7wSjY2K@_X^N{z$2aM>kJT+3Qx; za5FK%`24+PZ)`UdK0VH2xmRN!Faff2h;*Mn^aU8e$Dsj?Jr5}|*y0G_yyB9?ypm!_ zMzTZ+Qb@oONj5NPf-^|2K4&;EBQt<7D1%^kk2HfZR`-;p78L{YsGE~MsV*_*0)@uK zdFi1{K$}50z?%_7!OIlntSbYOKmqtl7>G%Ttd5+hRZ&dY4CG-m2_xg98;Kl^`Y1*| z21YDVMk4YCy3dj0L=nYs7iL_BBc%&;4Ki4J5`2gdp`0wEzcX?PdS~ literal 0 HcmV?d00001 diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_classification/vocabulary.json b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_classification/vocabulary.json new file mode 100644 index 0000000000000..9dc30191a0ca4 --- /dev/null +++ b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_classification/vocabulary.json @@ -0,0 +1,16 @@ +{ + "vocabulary": [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + "Hello", + "world", + "car", + "bike", + "bee", + "bird", + "and" + ] +} \ No newline at end of file diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/config.json b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/config.json new file mode 100644 index 0000000000000..9c5c226cfd232 --- /dev/null +++ b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/config.json @@ -0,0 +1,7 @@ +{ + "description": "Tiny/Dummy PyTorch model (text_embedding)", + "model_type": "pytorch", + "inference_config": { + "text_embedding": {} + } +} \ No newline at end of file diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/traced_pytorch_model.pt b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/traced_pytorch_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..c7d7686060915a73c306c01abce9ff8e02ef6c11 GIT binary patch literal 1517 zcmWIWW@cev;NW1u0OAbX3?-=*CGn}bNvSC*nR)5@DTyVCdIi}zZcgkBQ4r9;lpG&l zl3$dZ5g*SLTw0J?6r7n`kdqn$)ajg)SX`W$mYH0VnV-j1$e`I+s}TXRsiY_|GY@EB zA!BER7?2;ISsb5Mnv)Zsl$e}do>-I;pOK%RUC88DnpDWF7vRm#al?vjOEOR;2nPV2 zECO^grkfQQ_|cu5oS%}a4|kznL8Y4$H_YuIAGG!x<{dWRas580X)*5-#-Ih=jU_Md zlzEsPR!Eth#c_I(gJ(wd@2GQ^cs#Ei|Nisj!I=^dXE48<`24{O7yFlM9k$u1h{f+c zJwx{DSvCek2~%U$OC22S8M~)`yLn4{%9Rc2|5>}zLbMqV#~Xp2SB1t=u>(4q&zL#8H^YEpp-hxPHFU{5MU-!?IvJZP($F2YS`MP_57(rq7 zN@zvobzpeSVZ;|^S`2~&!%QzFHK{Z`9vt($$e|Y%HFd&yufq-k$I35@nQ&Y(zmd$P z7;w=0pn61@z8sU1;x6IM5#Q@C-rMsw)-~XM=$(knqRoFLW*pAnH}he&pHY{WNNPa0 z1{d4qbw`sJrw8{4ygaf{v3FDd)~7R@QzuWJ^2gI-@29B`@*MTHh#zpH?oD zPj1?%FaJC_GHX(i6!Yf~Pt3D-PAG4Un=Lq(#dz8MHw^4u%+&{T=RecVJN`ysv&klv z4cir~-`@~SJ3g_O>9@qD^Cx~yjc0yxcF7NL$bNC1;{OpCvIn3ci#-F$FtB2VY+i9m zVqQrxBo{F@kqGDk7a&t04jTmVDBt#i= zfs)F_dFi1{K$}50z?%_7!OH~XtS1hVKmj;P2#6twT#B3)6;KS=4CJ91f{|y@O+yZU zEfmw917n1cX^7N}?p5SSkU=rmkr|7*NO6g75^{ioN;m|Vf@u;o3In{^*mR%@<(PHh pT3Deh7+neUBM1nA(h3j&^|6EKZm0r~bbvQ28%T^52tn#0Y5}4K%?AJg literal 0 HcmV?d00001 diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/vocabulary.json b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/vocabulary.json new file mode 100644 index 0000000000000..9dc30191a0ca4 --- /dev/null +++ b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_text_embedding/vocabulary.json @@ -0,0 +1,16 @@ +{ + "vocabulary": [ + "[PAD]", + "[UNK]", + "[CLS]", + "[SEP]", + "[MASK]", + "Hello", + "world", + "car", + "bike", + "bee", + "bird", + "and" + ] +} \ No newline at end of file diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_zero_shot/config.json b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_zero_shot/config.json new file mode 100644 index 0000000000000..ce273c9f19993 --- /dev/null +++ b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_zero_shot/config.json @@ -0,0 +1,9 @@ +{ + "description": "Tiny/Dummy PyTorch model (zero_shot)", + "model_type": "pytorch", + "inference_config": { + "zero_shot_classification": { + "classification_labels": ["entailment", "neutral", "contradiction"] + } + } +} \ No newline at end of file diff --git a/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_zero_shot/traced_pytorch_model.pt b/x-pack/test/functional/services/ml/resources/trained_model_definitions/pt_tiny_zero_shot/traced_pytorch_model.pt new file mode 100644 index 0000000000000000000000000000000000000000..acffbb80eaf53766506eed1b3af62ff7ffe68486 GIT binary patch literal 1492 zcmWIWW@cev;NW1u0747`3{|N``SHaW`6c=(i6x181=%@nPCN`zNT7i!IX=E5zbH8) zKAtPMv>>%8I5W2(Cp8MFGZ?6mtB^snu~s7jWI;($VrE`uUV0&8XM`A#AD>wqpH`Za z6Q7iroL!z+loFqjpPyaG71i3mG4HSekIVNOR(ENS%_+tc{yE-0I5$h8B`0ueLZpZW zr{h1nX`59Xn%0!Ruh?^OrNqLsjZ3zA7kqI)^!CWgUrBRiub$;&Fw_X-R-M$rq5k5R zuk0p6{ve^$yVHehP3`NV8&XUq4Rl&pDp_7$ z!`hiTM}Xh|-|h8+46mL2GD`PY|M0fk^ksSex01y|=bsqHGVOI`c8@-@{lPVviPIY6 zCYjdM9P-WBHX-2Dqr};z0#8e?%6>0+^ZKqRDA4w=3*U7f7+@WY_ySFpK@MM_>7}G5 zm8QpoV_XV%@I_6XaNg^%gTS%!WNs;z_{2WVmY6E-A1tR|9d&k4aWPJ8F8$BGa=F&4 zeOq?~Wg%?X zi~IiWd9GEn>-gSvN3+tIyZ3Dusym*eJ3D9Q;i8GNme}+c?p+W%Rrr*6RkGg=n;DY7 zA1yaAkoqyL%kWsZR@$GQPbQmGK5W;l{(s(X*&E|cg{dACux;lPZ@0K%Y@j@<@923sT>%2JDp zGxPJ@oGh`MjgruDaE!S?{<}CYJ(LM(GYAKGGlD31nSh)Jg+LN00JDUE=ttyKn zG`a%3+1PZT%H)`J;Tl+>EEt^t^b`mPfszCe0QIqh=pv{BkaU1ID;r3R6$nA { const actualStats = await testSubjects.getVisibleText('mlInferenceModelsStatsBar'); From a0631345129027863cf57ae5cc793f1c799440d9 Mon Sep 17 00:00:00 2001 From: Robert Oskamp Date: Mon, 11 Jul 2022 13:55:46 +0200 Subject: [PATCH 2/2] Refactor type for trained model names --- .../model_management/model_list.ts | 21 +++----- x-pack/test/functional/services/ml/api.ts | 51 +++++++++++++++---- 2 files changed, 46 insertions(+), 26 deletions(-) diff --git a/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts b/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts index 688ab3d6ced80..4346ad0815e1c 100644 --- a/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts +++ b/x-pack/test/functional/apps/ml/short_tests/model_management/model_list.ts @@ -6,28 +6,19 @@ */ import { FtrProviderContext } from '../../../../ftr_provider_context'; -import { TrainedModelName } from '../../../../services/ml/api'; +import { SUPPORTED_TRAINED_MODELS } from '../../../../services/ml/api'; export default function ({ getService }: FtrProviderContext) { const ml = getService('ml'); - const tinyTrainedModels = [ - 'fill_mask', - 'ner', - 'pass_through', - 'text_classification', - 'text_embedding', - 'zero_shot', - ].map((type) => ({ - id: `tiny_${type}`, - name: `pt_tiny_${type}` as TrainedModelName, - desription: `Tiny/Dummy PyTorch model (${type})`, - modelTypes: ['pytorch', type], + const trainedModels = Object.values(SUPPORTED_TRAINED_MODELS).map((model) => ({ + ...model, + id: model.name, })); describe('trained models', function () { before(async () => { - for (const model of tinyTrainedModels) { + for (const model of trainedModels) { await ml.api.importTrainedModel(model.id, model.name); } @@ -100,7 +91,7 @@ export default function ({ getService }: FtrProviderContext) { await ml.trainedModelsTable.assertPipelinesTabContent(false); }); - for (const model of tinyTrainedModels) { + for (const model of trainedModels) { it(`renders expanded row content correctly for imported tiny model ${model.id} without pipelines`, async () => { await ml.trainedModelsTable.ensureRowIsExpanded(model.id); await ml.trainedModelsTable.assertDetailsTabContent(); diff --git a/x-pack/test/functional/services/ml/api.ts b/x-pack/test/functional/services/ml/api.ts index 71dac67f4e251..d64f238a66bdd 100644 --- a/x-pack/test/functional/services/ml/api.ts +++ b/x-pack/test/functional/services/ml/api.ts @@ -28,13 +28,42 @@ import { FtrProviderContext } from '../../ftr_provider_context'; export type MlApi = ProvidedType; type ModelType = 'regression' | 'classification'; -export type TrainedModelName = - | 'pt_tiny_fill_mask' - | 'pt_tiny_ner' - | 'pt_tiny_pass_through' - | 'pt_tiny_text_classification' - | 'pt_tiny_text_embedding' - | 'pt_tiny_zero_shot'; + +export const SUPPORTED_TRAINED_MODELS = { + TINY_FILL_MASK: { + name: 'pt_tiny_fill_mask', + description: 'Tiny/Dummy PyTorch model (fill_mask)', + modelTypes: ['pytorch', 'fill_mask'], + }, + TINY_NER: { + name: 'pt_tiny_ner', + description: 'Tiny/Dummy PyTorch model (ner)', + modelTypes: ['pytorch', 'ner'], + }, + TINY_PASS_THROUGH: { + name: 'pt_tiny_pass_through', + description: 'Tiny/Dummy PyTorch model (pass_through)', + modelTypes: ['pytorch', 'pass_through'], + }, + TINY_TEXT_CLASSIFICATION: { + name: 'pt_tiny_text_classification', + description: 'Tiny/Dummy PyTorch model (text_classification)', + modelTypes: ['pytorch', 'text_classification'], + }, + TINY_TEXT_EMBEDDING: { + name: 'pt_tiny_text_embedding', + description: 'Tiny/Dummy PyTorch model (text_embedding)', + modelTypes: ['pytorch', 'text_embedding'], + }, + TINY_ZERO_SHOT: { + name: 'pt_tiny_zero_shot', + description: 'Tiny/Dummy PyTorch model (zero_shot)', + modelTypes: ['pytorch', 'zero_shot'], + }, +} as const; +export type SupportedTrainedModelNamesType = + typeof SUPPORTED_TRAINED_MODELS[keyof typeof SUPPORTED_TRAINED_MODELS]['name']; + export interface TrainedModelVocabulary { vocabulary: string[]; } @@ -1234,7 +1263,7 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { ); }, - getTrainedModelConfig(modelName: TrainedModelName) { + getTrainedModelConfig(modelName: SupportedTrainedModelNamesType) { const configFileContent = fs.readFileSync( require.resolve(`./resources/trained_model_definitions/${modelName}/config.json`), 'utf-8' @@ -1242,7 +1271,7 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { return JSON.parse(configFileContent) as PutTrainedModelConfig; }, - getTrainedModelVocabulary(modelName: TrainedModelName) { + getTrainedModelVocabulary(modelName: SupportedTrainedModelNamesType) { const vocabularyFileContent = fs.readFileSync( require.resolve(`./resources/trained_model_definitions/${modelName}/vocabulary.json`), 'utf-8' @@ -1250,13 +1279,13 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) { return JSON.parse(vocabularyFileContent) as TrainedModelVocabulary; }, - getTrainedModelDefinitionPath(modelName: TrainedModelName) { + getTrainedModelDefinitionPath(modelName: SupportedTrainedModelNamesType) { return require.resolve( `./resources/trained_model_definitions/${modelName}/traced_pytorch_model.pt` ); }, - async importTrainedModel(modelId: string, modelName: TrainedModelName) { + async importTrainedModel(modelId: string, modelName: SupportedTrainedModelNamesType) { await this.createTrainedModel(modelId, this.getTrainedModelConfig(modelName)); await this.createTrainedModelVocabularyES(modelId, this.getTrainedModelVocabulary(modelName)); await this.uploadTrainedModelDefinitionES(