Skip to content

Commit

Permalink
Merge branch 'main' into 136039-rules-status-deploy
Browse files Browse the repository at this point in the history
  • Loading branch information
kibanamachine authored Nov 15, 2022
2 parents 6872e67 + d32e130 commit ed54bb8
Show file tree
Hide file tree
Showing 18 changed files with 748 additions and 167 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
* 2.0.
*/

import { IngestSetProcessor, MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types';
import {
IngestSetProcessor,
MlTrainedModelConfig,
MlTrainedModelStats,
} from '@elastic/elasticsearch/lib/api/types';
import { BUILT_IN_MODEL_TAG } from '@kbn/ml-plugin/common/constants/data_frame_analytics';
import { SUPPORTED_PYTORCH_TASKS } from '@kbn/ml-plugin/common/constants/trained_models';

import { MlInferencePipeline } from '../types/pipelines';
import { MlInferencePipeline, TrainedModelState } from '../types/pipelines';

import {
BUILT_IN_MODEL_TAG as LOCAL_BUILT_IN_MODEL_TAG,
Expand All @@ -18,6 +22,8 @@ import {
getSetProcessorForInferenceType,
SUPPORTED_PYTORCH_TASKS as LOCAL_SUPPORTED_PYTORCH_TASKS,
parseMlInferenceParametersFromPipeline,
parseModelStateFromStats,
parseModelStateReasonFromStats,
} from '.';

const mockModel: MlTrainedModelConfig = {
Expand Down Expand Up @@ -241,3 +247,80 @@ describe('parseMlInferenceParametersFromPipeline', () => {
).toBeNull();
});
});

describe('parseModelStateFromStats', () => {
  // Builds a minimal stats object carrying only the deployment state; cast is
  // needed because the mock omits the rest of MlTrainedModelStats.
  const statsWithDeploymentState = (state: string) =>
    ({ deployment_stats: { state } } as unknown as MlTrainedModelStats);

  it('returns not deployed for undefined stats', () => {
    expect(parseModelStateFromStats()).toEqual(TrainedModelState.NotDeployed);
  });
  it('returns Started', () => {
    expect(parseModelStateFromStats(statsWithDeploymentState('started'))).toEqual(
      TrainedModelState.Started
    );
  });
  it('returns Starting', () => {
    expect(parseModelStateFromStats(statsWithDeploymentState('starting'))).toEqual(
      TrainedModelState.Starting
    );
  });
  it('returns Stopping', () => {
    expect(parseModelStateFromStats(statsWithDeploymentState('stopping'))).toEqual(
      TrainedModelState.Stopping
    );
  });
  it('returns Failed', () => {
    expect(parseModelStateFromStats(statsWithDeploymentState('failed'))).toEqual(
      TrainedModelState.Failed
    );
  });
  it('returns not deployed if an unknown state is received', () => {
    expect(parseModelStateFromStats(statsWithDeploymentState('other thing'))).toEqual(
      TrainedModelState.NotDeployed
    );
  });
});

describe('parseModelStateReasonFromStats', () => {
  it('returns reason from deployment_stats', () => {
    const reason = 'This is the reason the model is in a failed state';
    expect(
      parseModelStateReasonFromStats({
        deployment_stats: {
          reason,
          state: 'failed',
        },
      } as unknown as MlTrainedModelStats)
    ).toEqual(reason);
  });
  it('returns undefined if reason not found from deployment_stats', () => {
    expect(
      parseModelStateReasonFromStats({
        deployment_stats: {
          state: 'failed',
        },
      } as unknown as MlTrainedModelStats)
    ).toBeUndefined();
  });
  // Test name fixed: was "returns undefined stats is undefined" (missing "if").
  it('returns undefined if stats is undefined', () => {
    expect(parseModelStateReasonFromStats(undefined)).toBeUndefined();
  });
});
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ import {
IngestPipeline,
IngestSetProcessor,
MlTrainedModelConfig,
MlTrainedModelStats,
} from '@elastic/elasticsearch/lib/api/types';

import { MlInferencePipeline, CreateMlInferencePipelineParameters } from '../types/pipelines';
import {
MlInferencePipeline,
CreateMlInferencePipelineParameters,
TrainedModelState,
} from '../types/pipelines';

// Getting an error importing this from @kbn/ml-plugin/common/constants/data_frame_analytics'
// So defining it locally for now with a test to make sure it matches.
Expand Down Expand Up @@ -177,3 +182,22 @@ export const parseMlInferenceParametersFromPipeline = (
source_field: sourceField,
};
};

export const parseModelStateFromStats = (trainedModelStats?: Partial<MlTrainedModelStats>) => {
switch (trainedModelStats?.deployment_stats?.state) {
case 'started':
return TrainedModelState.Started;
case 'starting':
return TrainedModelState.Starting;
case 'stopping':
return TrainedModelState.Stopping;
// @ts-ignore: type is wrong, "failed" is a possible state
case 'failed':
return TrainedModelState.Failed;
default:
return TrainedModelState.NotDeployed;
}
};

/**
 * Extracts the deployment state reason (e.g. why a model failed) from a
 * trained model's stats, if one was reported.
 *
 * @param trainedModelStats stats for a single trained model, if available.
 * @returns the reason string, or undefined when stats or reason are absent.
 */
export const parseModelStateReasonFromStats = (trainedModelStats?: Partial<MlTrainedModelStats>) => {
  return trainedModelStats?.deployment_stats?.reason;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import {
MlTrainedModelDeploymentStats,
MlTrainedModelStats,
} from '@elastic/elasticsearch/lib/api/types';
import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models';

// Mock NER (named entity recognition) PyTorch model config, shaped like a
// single entry from the ML trained models API response.
export const nerModel: TrainedModelConfigResponse = {
  inference_config: {
    ner: {
      // IOB-style entity labels: person, organization, location, miscellaneous.
      classification_labels: [
        'O',
        'B_PER',
        'I_PER',
        'B_ORG',
        'I_ORG',
        'B_LOC',
        'I_LOC',
        'B_MISC',
        'I_MISC',
      ],
      tokenization: {
        bert: {
          do_lower_case: false,
          max_sequence_length: 512,
          span: -1,
          truncate: 'first',
          with_special_tokens: true,
        },
      },
    },
  },
  input: {
    field_names: ['text_field'],
  },
  model_id: 'ner-mocked-model',
  model_type: 'pytorch',
  tags: [],
  version: '1',
};

// Mock text-classification (emotion) PyTorch model config, shaped like a
// single entry from the ML trained models API response.
export const textClassificationModel: TrainedModelConfigResponse = {
  inference_config: {
    text_classification: {
      classification_labels: ['anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise'],
      num_top_classes: 0,
      tokenization: {
        roberta: {
          add_prefix_space: false,
          do_lower_case: false,
          max_sequence_length: 512,
          span: -1,
          truncate: 'first',
          with_special_tokens: true,
        },
      },
    },
  },
  input: {
    field_names: ['text_field'],
  },
  model_id: 'text-classification-mocked-model',
  model_type: 'pytorch',
  tags: [],
  version: '2',
};

export const mlModels: TrainedModelConfigResponse[] = [nerModel, textClassificationModel];

// Mock response for the trained models stats API, covering both mock models.
// Each model is reported as deployed ("started", fully allocated).
export const mlModelStats: {
  count: number;
  trained_model_stats: MlTrainedModelStats[];
} = {
  count: 2,
  trained_model_stats: [
    {
      model_id: nerModel.model_id,
      model_size_stats: {
        model_size_bytes: 260831121,
        required_native_memory_bytes: 773320482,
      },
      pipeline_count: 0,
      // NOTE(review): double assertion (`as unknown as`) suggests this mock
      // omits fields the client type requires — confirm it stays in sync with
      // MlTrainedModelDeploymentStats.
      deployment_stats: {
        allocation_status: {
          allocation_count: 1,
          target_allocation_count: 1,
          state: 'fully_allocated',
        },
        error_count: 0,
        inference_count: 0,
        nodes: [],
        number_of_allocations: 1,
        state: 'started',
        threads_per_allocation: 1,
      } as unknown as MlTrainedModelDeploymentStats,
    },
    {
      // NOTE(review): same double-assertion caveat as above.
      deployment_stats: {
        allocation_status: {
          allocation_count: 1,
          target_allocation_count: 1,
          state: 'fully_allocated',
        },
        error_count: 0,
        inference_count: 0,
        nodes: [],
        number_of_allocations: 1,
        state: 'started',
        threads_per_allocation: 1,
      } as unknown as MlTrainedModelDeploymentStats,
      model_id: textClassificationModel.model_id,
      model_size_stats: {
        model_size_bytes: 260831121,
        required_native_memory_bytes: 773320482,
      },
      pipeline_count: 0,
    },
  ],
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { mockHttpValues } from '../../../__mocks__/kea_logic';
import { mlModelStats } from '../../__mocks__/ml_models.mock';

import { getMLModelsStats } from './ml_model_stats_logic';

describe('MLModelsApiLogic', () => {
  const { http } = mockHttpValues;

  beforeEach(() => {
    jest.clearAllMocks();
  });

  describe('getMLModelsStats', () => {
    it('calls the ml api', async () => {
      http.get.mockResolvedValue(mlModelStats);

      // Resolved value should be the raw API payload, passed through untouched.
      await expect(getMLModelsStats()).resolves.toEqual(mlModelStats);
      expect(http.get).toHaveBeenCalledWith('/api/ml/trained_models/_stats');
    });
  });
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
import { MlTrainedModelStats } from '@elastic/elasticsearch/lib/api/types';

import { Actions, createApiLogic } from '../../../shared/api_logic/create_api_logic';
import { HttpLogic } from '../../../shared/http';

// The stats request takes no arguments.
export type GetMlModelsStatsArgs = undefined;

// Shape of the response from GET /api/ml/trained_models/_stats.
export interface GetMlModelsStatsResponse {
  count: number;
  trained_model_stats: MlTrainedModelStats[];
}

/**
 * Fetches deployment/usage stats for all trained ML models.
 */
export const getMLModelsStats = async () =>
  HttpLogic.values.http.get<GetMlModelsStatsResponse>('/api/ml/trained_models/_stats');

// Kea API logic wrapping the stats fetch so components can trigger the request
// and subscribe to its status and data.
export const MLModelsStatsApiLogic = createApiLogic(
  ['ml_models_stats_api_logic'],
  getMLModelsStats
);

// Request/success/failure action types generated for this API logic.
export type MLModelsStatsApiLogicActions = Actions<GetMlModelsStatsArgs, GetMlModelsStatsResponse>;
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
* 2.0.
*/
import { mockHttpValues } from '../../../__mocks__/kea_logic';

import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models';
import { mlModels } from '../../__mocks__/ml_models.mock';

import { getMLModels } from './ml_models_logic';

Expand All @@ -17,55 +16,12 @@ describe('MLModelsApiLogic', () => {
});
describe('getMLModels', () => {
it('calls the ml api', async () => {
const response: Promise<TrainedModelConfigResponse[]> = Promise.resolve([
{
inference_config: {},
input: {
field_names: [],
},
model_id: 'a-model-001',
model_type: 'pytorch',
tags: ['pytorch', 'ner'],
version: '1',
},
{
inference_config: {},
input: {
field_names: [],
},
model_id: 'a-model-002',
model_type: 'lang_ident',
tags: [],
version: '2',
},
]);
http.get.mockReturnValue(response);
http.get.mockResolvedValue(mlModels);
const result = await getMLModels();
expect(http.get).toHaveBeenCalledWith('/api/ml/trained_models', {
query: { size: 1000, with_pipelines: true },
});
expect(result).toEqual([
{
inference_config: {},
input: {
field_names: [],
},
model_id: 'a-model-001',
model_type: 'pytorch',
tags: ['pytorch', 'ner'],
version: '1',
},
{
inference_config: {},
input: {
field_names: [],
},
model_id: 'a-model-002',
model_type: 'lang_ident',
tags: [],
version: '2',
},
]);
expect(result).toEqual(mlModels);
});
});
});
Loading

0 comments on commit ed54bb8

Please sign in to comment.