Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ML] Functional tests - support tiny trained models #136010

Merged
merged 3 commits into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions x-pack/plugins/ml/common/types/trained_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ export type PutTrainedModelConfig = {
model_aliases?: string[];
} & Record<string, unknown>;
tags?: string[];
model_type?: TrainedModelType;
inference_config?: Record<string, unknown>;
input: { field_names: string[] };
} & XOR<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,33 @@
*/

import { FtrProviderContext } from '../../../../ftr_provider_context';
import { TrainedModelName } from '../../../../services/ml/api';

export default function ({ getService }: FtrProviderContext) {
const ml = getService('ml');

// Descriptors for the tiny PyTorch test models, one entry per task type.
// `id` is the model id used when importing, `name` is the resource name
// handed to `ml.api.importTrainedModel`.
const tinyTrainedModels = [
  'fill_mask',
  'ner',
  'pass_through',
  'text_classification',
  'text_embedding',
  'zero_shot',
].map((type) => ({
  id: `tiny_${type}`,
  name: `pt_tiny_${type}` as TrainedModelName,
  // Key was previously misspelled as `desription`.
  description: `Tiny/Dummy PyTorch model (${type})`,
  modelTypes: ['pytorch', type],
}));

describe('trained models', function () {
before(async () => {
await ml.trainedModels.createTestTrainedModels('classification', 15, true);
await ml.trainedModels.createTestTrainedModels('regression', 15);
for (const model of tinyTrainedModels) {
await ml.api.importTrainedModel(model.id, model.name);
}

await ml.api.createTestTrainedModels('classification', 15, true);
await ml.api.createTestTrainedModels('regression', 15);
});

after(async () => {
Expand Down Expand Up @@ -56,7 +75,7 @@ export default function ({ getService }: FtrProviderContext) {
'should display the stats bar with the total number of models'
);
// +1 because of the built-in model
await ml.trainedModels.assertStats(31);
await ml.trainedModels.assertStats(37);

await ml.testExecution.logTestStep('should display the table');
await ml.trainedModels.assertTableExists();
Expand All @@ -81,6 +100,16 @@ export default function ({ getService }: FtrProviderContext) {
await ml.trainedModelsTable.assertPipelinesTabContent(false);
});

// Registers one expanded-row test case per imported tiny model.
for (const { id: tinyModelId } of tinyTrainedModels) {
  const testTitle = `renders expanded row content correctly for imported tiny model ${tinyModelId} without pipelines`;

  it(testTitle, async () => {
    const { trainedModelsTable } = ml;

    await trainedModelsTable.ensureRowIsExpanded(tinyModelId);

    // Check each tab of the expanded row in turn.
    await trainedModelsTable.assertDetailsTabContent();
    await trainedModelsTable.assertInferenceConfigTabContent();
    await trainedModelsTable.assertStatsTabContent();
    // `false`: per the test title, these models are expected to have no pipelines.
    await trainedModelsTable.assertPipelinesTabContent(false);
  });
}

it('displays the built-in model and no actions are enabled', async () => {
await ml.testExecution.logTestStep('should display the model in the table');
await ml.trainedModelsTable.filterWithSearchString(builtInModelData.modelId, 1);
Expand Down
102 changes: 78 additions & 24 deletions x-pack/test/functional/services/ml/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import expect from '@kbn/expect';
import { ProvidedType } from '@kbn/test';
import fs from 'fs';
import path from 'path';
import { Calendar } from '@kbn/ml-plugin/server/models/calendar';
import { Annotation } from '@kbn/ml-plugin/common/types/annotations';
import { DataFrameAnalyticsConfig } from '@kbn/ml-plugin/public/application/data_frame_analytics/common';
Expand All @@ -29,6 +28,16 @@ import { FtrProviderContext } from '../../ftr_provider_context';
export type MlApi = ProvidedType<typeof MachineLearningAPIProvider>;

type ModelType = 'regression' | 'classification';
/**
 * Names of the tiny PyTorch model fixtures that can be imported in tests.
 *
 * Each name is used as a directory name below
 * `./resources/trained_model_definitions/`, from which `config.json`,
 * `vocabulary.json` and `traced_pytorch_model.pt` are resolved.
 */
export type TrainedModelName =
| 'pt_tiny_fill_mask'
| 'pt_tiny_ner'
| 'pt_tiny_pass_through'
| 'pt_tiny_text_classification'
| 'pt_tiny_text_embedding'
| 'pt_tiny_zero_shot';
/**
 * Shape of a trained model vocabulary: the content parsed from a fixture's
 * `vocabulary.json` and the request body sent to the
 * `/_ml/trained_models/<model_id>/vocabulary` endpoint.
 */
export interface TrainedModelVocabulary {
// List of vocabulary tokens.
vocabulary: string[];
}

export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
const es = getService('es');
Expand Down Expand Up @@ -1135,6 +1144,38 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
return model;
},

/**
 * Sends the given vocabulary to Elasticsearch for an existing trained model
 * and checks that the request was answered with status 200.
 *
 * @param modelId id of the trained model the vocabulary belongs to
 * @param body vocabulary payload, sent as the request body unchanged
 */
async createTrainedModelVocabularyES(modelId: string, body: TrainedModelVocabulary) {
  log.debug(`Creating vocabulary for trained model "${modelId}"`);

  const vocabularyUrl = `/_ml/trained_models/${modelId}/vocabulary`;
  const response = await esSupertest.put(vocabularyUrl).send(body);
  this.assertResponseStatusCode(200, response.status, response.body);

  log.debug('> Trained model vocabulary created');
},

/**
 * Uploads the model definition file for a trained model to Elasticsearch.
 *
 * For the purpose of the functional tests where we only deal with very
 * small models, we assume that the model definition can be uploaded as
 * one part (part number 0, `total_parts: 1`).
 *
 * @param modelId id of the trained model the definition belongs to
 * @param modelDefinitionPath path of the model definition file on disk
 */
async uploadTrainedModelDefinitionES(modelId: string, modelDefinitionPath: string) {
  log.debug(`Uploading definition for trained model "${modelId}"`);

  // Read the file once and take the length from that same buffer, so the
  // reported length always matches the payload that is actually sent.
  const definitionBuffer = fs.readFileSync(modelDefinitionPath);

  const body = {
    // Length of the raw definition in bytes, i.e. before base64 encoding.
    total_definition_length: definitionBuffer.byteLength,
    definition: definitionBuffer.toString('base64'),
    total_parts: 1,
  };

  const { body: responseBody, status } = await esSupertest
    .put(`/_ml/trained_models/${modelId}/definition/0`)
    .send(body);
  this.assertResponseStatusCode(200, status, responseBody);

  log.debug('> Trained model definition uploaded');
},

async deleteTrainedModelES(modelId: string) {
log.debug(`Creating trained model with id "${modelId}"`);
const { body: model, status } = await esSupertest.delete(`/_ml/trained_models/${modelId}`);
Expand All @@ -1149,24 +1190,9 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
count: number = 10,
withIngestPipelines = false
) {
const compressedDefinition = this.getCompressedModelDefinition(modelType);
const modelIds = new Array(count).fill(null).map((_v, i) => `dfa_${modelType}_model_n_${i}`);

const modelIds = new Array(count).fill(null).map((v, i) => `dfa_${modelType}_model_n_${i}`);

const models = modelIds.map((id) => {
return {
model_id: id,
body: {
compressed_definition: compressedDefinition,
inference_config: {
[modelType]: {},
},
input: {
field_names: ['common_field'],
},
} as PutTrainedModelConfig,
};
});
const models = modelIds.map((id) => this.createTestTrainedModelConfig(id, modelType));

for (const model of models) {
await this.createTrainedModel(model.model_id, model.body);
Expand All @@ -1178,7 +1204,7 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
return modelIds;
},

async createTestTrainedModelConfig(modelId: string, modelType: ModelType) {
createTestTrainedModelConfig(modelId: string, modelType: ModelType) {
const compressedDefinition = this.getCompressedModelDefinition(modelType);

return {
Expand All @@ -1201,16 +1227,44 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
*/
getCompressedModelDefinition(modelType: ModelType) {
return fs.readFileSync(
path.resolve(
__dirname,
'resources',
'trained_model_definitions',
`minimum_valid_config_${modelType}.json.gz.b64`
require.resolve(
`./resources/trained_model_definitions/minimum_valid_config_${modelType}.json.gz.b64`
),
'utf-8'
);
},

/**
 * Loads the `config.json` fixture of the given tiny model.
 *
 * @param modelName name of the fixture directory to read from
 * @returns the parsed JSON, cast to `PutTrainedModelConfig` without validation
 */
getTrainedModelConfig(modelName: TrainedModelName) {
  const configPath = require.resolve(
    `./resources/trained_model_definitions/${modelName}/config.json`
  );
  const rawConfig = fs.readFileSync(configPath, 'utf-8');

  return JSON.parse(rawConfig) as PutTrainedModelConfig;
},

/**
 * Loads the `vocabulary.json` fixture of the given tiny model.
 *
 * @param modelName name of the fixture directory to read from
 * @returns the parsed JSON, cast to `TrainedModelVocabulary` without validation
 */
getTrainedModelVocabulary(modelName: TrainedModelName) {
  const vocabularyPath = require.resolve(
    `./resources/trained_model_definitions/${modelName}/vocabulary.json`
  );
  const rawVocabulary = fs.readFileSync(vocabularyPath, 'utf-8');

  return JSON.parse(rawVocabulary) as TrainedModelVocabulary;
},

/**
 * Resolves the path of the `traced_pytorch_model.pt` fixture of the given
 * tiny model.
 *
 * @param modelName name of the fixture directory to look in
 * @returns the path returned by `require.resolve` for the model file
 */
getTrainedModelDefinitionPath(modelName: TrainedModelName) {
  const definitionFile = `./resources/trained_model_definitions/${modelName}/traced_pytorch_model.pt`;
  return require.resolve(definitionFile);
},

/**
 * Imports one of the tiny model fixtures under the given model id.
 *
 * The three steps run strictly in sequence: create the model from its
 * config, add its vocabulary, then upload its definition.
 *
 * @param modelId id under which the model is created
 * @param modelName name of the fixture to import
 */
async importTrainedModel(modelId: string, modelName: TrainedModelName) {
  // Step 1: model config
  const modelConfig = this.getTrainedModelConfig(modelName);
  await this.createTrainedModel(modelId, modelConfig);

  // Step 2: vocabulary
  const modelVocabulary = this.getTrainedModelVocabulary(modelName);
  await this.createTrainedModelVocabularyES(modelId, modelVocabulary);

  // Step 3: model definition
  const definitionPath = this.getTrainedModelDefinitionPath(modelName);
  await this.uploadTrainedModelDefinitionES(modelId, definitionPath);
},

async createModelAlias(modelId: string, modelAlias: string) {
log.debug(`Creating alias for model "${modelId}"`);
const { body, status } = await esSupertest.put(
Expand Down
2 changes: 1 addition & 1 deletion x-pack/test/functional/services/ml/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ export function MachineLearningProvider(context: FtrProviderContext) {
const testResources = MachineLearningTestResourcesProvider(context, api);
const alerting = MachineLearningAlertingProvider(context, api, commonUI);
const swimLane = SwimLaneProvider(context);
const trainedModels = TrainedModelsProvider(context, api, commonUI);
const trainedModels = TrainedModelsProvider(context, commonUI);
const trainedModelsTable = TrainedModelsTableProvider(context, commonUI);
const mlNodesPanel = MlNodesPanelProvider(context);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"description": "Tiny/Dummy PyTorch model (fill_mask)",
"model_type": "pytorch",
"inference_config": {
"fill_mask": {}
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"vocabulary": [
"[PAD]",
"[UNK]",
"[CLS]",
"[SEP]",
"[MASK]",
"Hello",
"world",
"car",
"bike",
"bee",
"bird",
"and"
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"description": "Tiny/Dummy PyTorch model (ner)",
"model_type": "pytorch",
"inference_config": {
"ner": {}
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
{
"vocabulary": [
"[PAD]",
"[UNK]",
"[CLS]",
"[SEP]",
"[MASK]",
"Hello",
"world",
"car",
"bike",
"bee",
"bird",
"and",
"my",
"name",
"is",
"I'm",
"Spartacus"
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"description": "Tiny/Dummy PyTorch model (pass_through)",
"model_type": "pytorch",
"inference_config": {
"pass_through": {
"tokenization": {
"bert": {
"with_special_tokens": false
}
}
}
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"vocabulary": [
"[UNK]", "[PAD]", "there", "is", "no", "spoon"
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"description": "Tiny/Dummy PyTorch model (text_classification)",
"model_type": "pytorch",
"inference_config": {
"text_classification": {
"classification_labels": ["POSITIVE", "NEGATIVE"]
}
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"vocabulary": [
"[PAD]",
"[UNK]",
"[CLS]",
"[SEP]",
"[MASK]",
"Hello",
"world",
"car",
"bike",
"bee",
"bird",
"and"
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"description": "Tiny/Dummy PyTorch model (text_embedding)",
"model_type": "pytorch",
"inference_config": {
"text_embedding": {}
}
}
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"vocabulary": [
"[PAD]",
"[UNK]",
"[CLS]",
"[SEP]",
"[MASK]",
"Hello",
"world",
"car",
"bike",
"bee",
"bird",
"and"
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"description": "Tiny/Dummy PyTorch model (zero_shot)",
"model_type": "pytorch",
"inference_config": {
"zero_shot_classification": {
"classification_labels": ["entailment", "neutral", "contradiction"]
}
}
}
Binary file not shown.
Loading