[ML] Fixing space checks for recently changed trained model apis (elastic#156238)

Fixes issues raised in
elastic#155375 (comment)
Kibana trained model endpoints for `_stop`, `_update` and `infer` now
require the model ID as well as the deployment ID to be passed to them.
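
For reference, the updated route shapes carry both IDs in the path (a sketch; the IDs are illustrative placeholders):

```typescript
// Updated Kibana route shapes; `my_model` and `my_deployment` are placeholders.
const updatePath = '/api/ml/trained_models/my_model/my_deployment/deployment/_update';
const stopPath = '/api/ml/trained_models/my_model/my_deployment/deployment/_stop';
const inferPath = '/api/ml/trained_models/infer/my_model/my_deployment';
```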

Also fixes the stop trained model API when stopping more than one
deployment. It's very likely the Elasticsearch `_stop` API will not support
a comma-separated list of deployment IDs for this release, so this change
calls `_stop` in a loop for each deployment. This also allows for better
reporting if any of the deployments fail to stop.
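
A minimal sketch of that loop, assuming a hypothetical `stopDeployment` helper in place of the real Elasticsearch call:

```typescript
// Sketch only: `stopDeployment` is a hypothetical stand-in for the single
// Elasticsearch _stop call. Errors are collected per deployment instead of
// aborting the whole batch, enabling per-deployment failure reporting.
async function stopAllDeployments(
  stopDeployment: (id: string) => Promise<{ stopped: boolean }>,
  deploymentIds: string[]
): Promise<Record<string, { success: boolean; error?: unknown }>> {
  const results: Record<string, { success: boolean; error?: unknown }> = {};
  for (const id of deploymentIds) {
    try {
      const { stopped } = await stopDeployment(id);
      results[id] = { success: stopped };
    } catch (error) {
      results[id] = { success: false, error };
    }
  }
  return results;
}
```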

(cherry picked from commit 9559bee)
jgowdyelastic committed May 3, 2023
1 parent e86a6e1 commit 206ba1b
Showing 11 changed files with 153 additions and 44 deletions.
@@ -42,7 +42,7 @@ export function useModelActions({
onTestAction: (model: ModelItem) => void;
onModelsDeleteRequest: (modelsIds: string[]) => void;
onLoading: (isLoading: boolean) => void;
fetchModels: () => void;
fetchModels: () => Promise<void>;
modelAndDeploymentIds: string[];
}): Array<Action<ModelItem>> {
const {
@@ -236,9 +236,13 @@ export function useModelActions({

try {
onLoading(true);
await trainedModelsApiService.updateModelDeployment(deploymentParams.deploymentId!, {
number_of_allocations: deploymentParams.numOfAllocations,
});
await trainedModelsApiService.updateModelDeployment(
item.model_id,
deploymentParams.deploymentId!,
{
number_of_allocations: deploymentParams.numOfAllocations,
}
);
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.updateSuccess', {
defaultMessage: 'Deployment for "{modelId}" has been updated successfully.',
@@ -294,19 +298,38 @@ export function useModelActions({

try {
onLoading(true);
await trainedModelsApiService.stopModelAllocation(deploymentIds, {
force: requireForceStop,
});
const results = await trainedModelsApiService.stopModelAllocation(
item.model_id,
deploymentIds,
{
force: requireForceStop,
}
);
displaySuccessToast(
i18n.translate('xpack.ml.trainedModels.modelsList.stopSuccess', {
defaultMessage: 'Deployment for "{modelId}" has been stopped successfully.',
defaultMessage:
'{numberOfDeployments, plural, one {Deployment} other {Deployments}} for "{modelId}" has been stopped successfully.',
values: {
modelId: item.model_id,
numberOfDeployments: deploymentIds.length,
},
})
);
// Need to fetch model state updates
await fetchModels();
if (Object.values(results).some((r) => r.error !== undefined)) {
Object.entries(results).forEach(([id, r]) => {
if (r.error !== undefined) {
displayErrorToast(
r.error,
i18n.translate('xpack.ml.trainedModels.modelsList.stopDeploymentWarning', {
defaultMessage: 'Failed to stop "{deploymentId}"',
values: {
deploymentId: id,
},
})
);
}
});
}
} catch (e) {
displayErrorToast(
e,
@@ -319,6 +342,8 @@ export function useModelActions({
);
onLoading(false);
}
// Need to fetch model state updates
await fetchModels();
},
},
{
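A companion sketch to the toast logic above: pulling the failed deployment IDs out of the per-deployment results map returned by the `_stop` route (the type alias is a trimmed stand-in):

```typescript
// Mirrors the error-reporting loop above: an entry with a defined `error`
// represents a deployment that failed to stop.
type StopResults = Record<string, { success: boolean; error?: unknown }>;

function failedDeploymentIds(results: StopResults): string[] {
  return Object.entries(results)
    .filter(([, result]) => result.error !== undefined)
    .map(([id]) => id);
}
```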
@@ -280,7 +280,8 @@ export abstract class InferenceBase<TInferResponse> {
const inferenceConfig = getInferenceConfig();

const resp = (await this.trainedModelsApi.inferTrainedModel(
this.deploymentId ?? this.model.model_id,
this.model.model_id,
this.deploymentId,
{
docs: this.getInferDocs(),
...(inferenceConfig ? { inference_config: inferenceConfig } : {}),
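The call-order change above in isolation (a trimmed sketch; the signature and doc shape are illustrative, not the real class API):

```typescript
// The model ID is always the first argument and the deployment ID the
// second, replacing the previous deploymentId-or-modelId fallback.
type InferFn = (
  modelId: string,
  deploymentId: string,
  payload: { docs: Array<Record<string, unknown>> }
) => Promise<unknown>;

async function runInference(infer: InferFn, modelId: string, deploymentId: string) {
  // `text_field` is an illustrative document field name.
  return infer(modelId, deploymentId, { docs: [{ text_field: 'some input text' }] });
}
```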
@@ -8,8 +8,9 @@
import * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';

import { useMemo } from 'react';
import { HttpFetchQuery } from '@kbn/core/public';
import { MlSavedObjectType } from '../../../../common/types/saved_objects';
import type { HttpFetchQuery } from '@kbn/core/public';
import type { ErrorType } from '@kbn/ml-error-utils';
import type { MlSavedObjectType } from '../../../../common/types/saved_objects';
import { HttpService } from '../http_service';
import { basePath } from '.';
import { useMlKibana } from '../../contexts/kibana';
@@ -142,32 +143,43 @@ export function trainedModelsApiProvider(httpService: HttpService) {
});
},

stopModelAllocation(deploymentsIds: string[], options: { force: boolean } = { force: false }) {
stopModelAllocation(
modelId: string,
deploymentsIds: string[],
options: { force: boolean } = { force: false }
) {
const force = options?.force;

return httpService.http<{ acknowledge: boolean }>({
path: `${apiBasePath}/trained_models/${deploymentsIds.join(',')}/deployment/_stop`,
return httpService.http<Record<string, { acknowledge: boolean; error?: ErrorType }>>({
path: `${apiBasePath}/trained_models/${modelId}/${deploymentsIds.join(
','
)}/deployment/_stop`,
method: 'POST',
query: { force },
});
},

updateModelDeployment(modelId: string, params: { number_of_allocations: number }) {
updateModelDeployment(
modelId: string,
deploymentId: string,
params: { number_of_allocations: number }
) {
return httpService.http<{ acknowledge: boolean }>({
path: `${apiBasePath}/trained_models/${modelId}/deployment/_update`,
path: `${apiBasePath}/trained_models/${modelId}/${deploymentId}/deployment/_update`,
method: 'POST',
body: JSON.stringify(params),
});
},

inferTrainedModel(
modelId: string,
deploymentsId: string,
payload: estypes.MlInferTrainedModelRequest['body'],
timeout?: string
) {
const body = JSON.stringify(payload);
return httpService.http<estypes.MlInferTrainedModelResponse>({
path: `${apiBasePath}/trained_models/infer/${modelId}`,
path: `${apiBasePath}/trained_models/infer/${modelId}/${deploymentsId}`,
method: 'POST',
body,
...(timeout ? { query: { timeout } as HttpFetchQuery } : {}),
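A hedged usage sketch of the updated provider signatures; the interface below is a trimmed stand-in for the real provider, and the IDs are placeholders:

```typescript
// Trimmed stand-in covering only the signatures changed in this commit.
interface TrainedModelsApiSketch {
  stopModelAllocation(
    modelId: string,
    deploymentsIds: string[],
    options?: { force: boolean }
  ): Promise<Record<string, { acknowledge: boolean; error?: unknown }>>;
  updateModelDeployment(
    modelId: string,
    deploymentId: string,
    params: { number_of_allocations: number }
  ): Promise<{ acknowledge: boolean }>;
  inferTrainedModel(
    modelId: string,
    deploymentsId: string,
    payload: { docs: Array<Record<string, unknown>> },
    timeout?: string
  ): Promise<unknown>;
}

async function exampleUsage(api: TrainedModelsApiSketch) {
  // Stop two deployments of one model; results are keyed by deployment ID.
  await api.stopModelAllocation('my_model', ['deploy_1', 'deploy_2'], { force: false });
  // Update the allocation count of a single named deployment.
  await api.updateModelDeployment('my_model', 'deploy_1', { number_of_allocations: 2 });
  // Run inference against a specific deployment.
  await api.inferTrainedModel('my_model', 'deploy_1', { docs: [{ text_field: 'hello' }] }, '30s');
}
```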
22 changes: 20 additions & 2 deletions x-pack/plugins/ml/server/lib/ml_client/ml_client.ts
@@ -132,6 +132,17 @@ export function getMlClient(
}
}

function switchDeploymentId(
p: Parameters<MlClient['stopTrainedModelDeployment']>
): Parameters<MlClient['stopTrainedModelDeployment']> {
const [params] = p;
if (params.deployment_id !== undefined) {
params.model_id = params.deployment_id;
delete params.deployment_id;
}
return p;
}

async function checkModelIds(modelIds: string[], allowWildcards: boolean = false) {
const filteredModelIds = await mlSavedObjectService.filterTrainedModelIdsForSpace(modelIds);
let missingIds = modelIds.filter((j) => filteredModelIds.indexOf(j) === -1);
@@ -491,17 +502,24 @@ export function getMlClient(
return mlClient.startTrainedModelDeployment(...p);
},
async updateTrainedModelDeployment(...p: Parameters<MlClient['updateTrainedModelDeployment']>) {
const { model_id: modelId, number_of_allocations: numberOfAllocations } = p[0];
await modelIdsCheck(p);

const { deployment_id: deploymentId, number_of_allocations: numberOfAllocations } = p[0];
return client.asInternalUser.transport.request({
method: 'POST',
path: `/_ml/trained_models/${modelId}/deployment/_update`,
path: `/_ml/trained_models/${deploymentId}/deployment/_update`,
body: { number_of_allocations: numberOfAllocations },
});
},
async stopTrainedModelDeployment(...p: Parameters<MlClient['stopTrainedModelDeployment']>) {
await modelIdsCheck(p);
switchDeploymentId(p);

return mlClient.stopTrainedModelDeployment(...p);
},
async inferTrainedModel(...p: Parameters<MlClient['inferTrainedModel']>) {
await modelIdsCheck(p);
switchDeploymentId(p);
// Temporary workaround for the incorrect inferTrainedModelDeployment function in the esclient
if (
// @ts-expect-error TS complains it's always false
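The deployment-ID swap above as a standalone sketch: the space check validates the real model ID first, then the deployment ID (when present) is moved into `model_id`, because that is the path parameter the Elasticsearch client actually sends:

```typescript
// Runs after the space check, so the genuine model ID has already been
// validated against the current space before it is overwritten.
interface DeploymentAddressedParams {
  model_id: string;
  deployment_id?: string;
}

function switchDeploymentIdSketch(params: DeploymentAddressedParams): DeploymentAddressedParams {
  if (params.deployment_id !== undefined) {
    params.model_id = params.deployment_id;
    delete params.deployment_id;
  }
  return params;
}
```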
25 changes: 23 additions & 2 deletions x-pack/plugins/ml/server/lib/ml_client/types.ts
@@ -5,23 +5,44 @@
* 2.0.
*/

import { ElasticsearchClient } from '@kbn/core/server';
import type { TransportRequestOptionsWithMeta } from '@elastic/elasticsearch';
import type * as estypes from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import type { ElasticsearchClient } from '@kbn/core/server';
import { searchProvider } from './search';

type OrigMlClient = ElasticsearchClient['ml'];
export interface UpdateTrainedModelDeploymentRequest {
model_id: string;
deployment_id?: string;
number_of_allocations: number;
}
export interface UpdateTrainedModelDeploymentResponse {
acknowledge: boolean;
}

export interface MlClient extends OrigMlClient {
export interface MlStopTrainedModelDeploymentRequest
extends estypes.MlStopTrainedModelDeploymentRequest {
deployment_id?: string;
}

export interface MlInferTrainedModelRequest extends estypes.MlInferTrainedModelRequest {
deployment_id?: string;
}

export interface MlClient
extends Omit<OrigMlClient, 'stopTrainedModelDeployment' | 'inferTrainedModel'> {
anomalySearch: ReturnType<typeof searchProvider>['anomalySearch'];
updateTrainedModelDeployment: (
payload: UpdateTrainedModelDeploymentRequest
) => Promise<UpdateTrainedModelDeploymentResponse>;
stopTrainedModelDeployment: (
p: MlStopTrainedModelDeploymentRequest,
options?: TransportRequestOptionsWithMeta
) => Promise<estypes.MlStopTrainedModelDeploymentResponse>;
inferTrainedModel: (
p: MlInferTrainedModelRequest,
options?: TransportRequestOptionsWithMeta
) => Promise<estypes.MlInferTrainedModelResponse>;
}

export type MlClientParams =
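The `Omit` pattern used above, reduced to its essentials: removing the inherited method lets the interface redeclare it with a widened parameter type without a conflict (names below are illustrative):

```typescript
// Omit drops the inherited method so the extension can redeclare it with an
// extra optional field in the parameter object.
interface BaseClient {
  stopTrainedModelDeployment(p: { model_id: string }): Promise<unknown>;
}

interface WidenedClient extends Omit<BaseClient, 'stopTrainedModelDeployment'> {
  stopTrainedModelDeployment(p: { model_id: string; deployment_id?: string }): Promise<unknown>;
}
```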
17 changes: 17 additions & 0 deletions x-pack/plugins/ml/server/routes/schemas/inference_schema.ts
@@ -14,6 +14,17 @@ export const modelIdSchema = schema.object({
modelId: schema.string(),
});

export const modelAndDeploymentIdSchema = schema.object({
/**
* Model ID
*/
modelId: schema.string(),
/**
* Deployment ID
*/
deploymentId: schema.string(),
});

export const threadingParamsSchema = schema.maybe(
schema.object({
number_of_allocations: schema.number(),
@@ -55,3 +66,9 @@ export const pipelineSimulateBody = schema.object({
docs: schema.arrayOf(schema.any()),
});
export const pipelineDocs = schema.arrayOf(schema.string());

export const stopDeploymentSchema = schema.object({
modelId: schema.string(),
/** force stop */
force: schema.maybe(schema.boolean()),
});
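
How the new params schema behaves at the route boundary (a small sketch using `@kbn/config-schema`, which this file already imports; the input values are placeholders):

```typescript
import { schema } from '@kbn/config-schema';

// Mirrors modelAndDeploymentIdSchema above: validate() throws if either ID
// is missing or not a string, so the handler can safely destructure both.
const paramsSchema = schema.object({
  modelId: schema.string(),
  deploymentId: schema.string(),
});

const { modelId, deploymentId } = paramsSchema.validate({
  modelId: 'my_model',
  deploymentId: 'my_deployment',
});
```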
44 changes: 30 additions & 14 deletions x-pack/plugins/ml/server/routes/trained_models.ts
@@ -6,12 +6,14 @@
*/

import { schema } from '@kbn/config-schema';
import { ErrorType } from '@kbn/ml-error-utils';
import { RouteInitialization } from '../types';
import { wrapError } from '../client/error_wrapper';
import {
getInferenceQuerySchema,
inferTrainedModelBody,
inferTrainedModelQuery,
modelAndDeploymentIdSchema,
modelIdSchema,
optionalModelIdSchema,
pipelineSimulateBody,
@@ -327,9 +329,9 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization)
*/
router.post(
{
path: '/api/ml/trained_models/{modelId}/deployment/_update',
path: '/api/ml/trained_models/{modelId}/{deploymentId}/deployment/_update',
validate: {
params: modelIdSchema,
params: modelAndDeploymentIdSchema,
body: updateDeploymentParamsSchema,
},
options: {
@@ -338,9 +340,10 @@
},
routeGuard.fullLicenseAPIGuard(async ({ mlClient, request, response }) => {
try {
const { modelId } = request.params;
const { modelId, deploymentId } = request.params;
const body = await mlClient.updateTrainedModelDeployment({
model_id: modelId,
deployment_id: deploymentId,
...request.body,
});
return response.ok({
@@ -361,9 +364,9 @@
*/
router.post(
{
path: '/api/ml/trained_models/{modelId}/deployment/_stop',
path: '/api/ml/trained_models/{modelId}/{deploymentId}/deployment/_stop',
validate: {
params: modelIdSchema,
params: modelAndDeploymentIdSchema,
query: forceQuerySchema,
},
options: {
@@ -372,13 +375,25 @@
},
routeGuard.fullLicenseAPIGuard(async ({ mlClient, request, response }) => {
try {
const { modelId } = request.params;
const body = await mlClient.stopTrainedModelDeployment({
model_id: modelId,
force: request.query.force ?? false,
});
const { deploymentId, modelId } = request.params;

const results: Record<string, { success: boolean; error?: ErrorType }> = {};

for (const id of deploymentId.split(',')) {
try {
const { stopped: success } = await mlClient.stopTrainedModelDeployment({
model_id: modelId,
deployment_id: id,
force: request.query.force ?? false,
allow_no_match: false,
});
results[id] = { success };
} catch (error) {
results[id] = { success: false, error };
}
}
return response.ok({
body,
body: results,
});
} catch (e) {
return response.customError(wrapError(e));
@@ -428,9 +443,9 @@ export function trainedModelsRoutes({ router, routeGuard }: RouteInitialization)
*/
router.post(
{
path: '/api/ml/trained_models/infer/{modelId}',
path: '/api/ml/trained_models/infer/{modelId}/{deploymentId}',
validate: {
params: modelIdSchema,
params: modelAndDeploymentIdSchema,
query: inferTrainedModelQuery,
body: inferTrainedModelBody,
},
@@ -440,9 +455,10 @@
},
routeGuard.fullLicenseAPIGuard(async ({ mlClient, request, response }) => {
try {
const { modelId } = request.params;
const { modelId, deploymentId } = request.params;
const body = await mlClient.inferTrainedModel({
model_id: modelId,
deployment_id: deploymentId,
body: {
docs: request.body.docs,
...(request.body.inference_config
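An illustrative client-side call against the updated `_stop` route; the IDs and headers are placeholders, and multiple deployment IDs are comma-separated in the path:

```typescript
// Illustrative only: IDs are placeholders and auth details are assumed. The
// response maps each requested deployment ID to its own success flag or
// error, matching the route handler above.
async function stopDeploymentsViaRoute(): Promise<
  Record<string, { success: boolean; error?: unknown }>
> {
  const res = await fetch(
    '/api/ml/trained_models/my_model/deploy_1,deploy_2/deployment/_stop?force=false',
    { method: 'POST', headers: { 'kbn-xsrf': 'true' } }
  );
  return res.json();
}
```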