Skip to content

Commit

Permalink
feat(stepfunctions-tasks): algorithmName validation for `SageMakerC…
Browse files Browse the repository at this point in the history
…reateTrainingJob` (#26877)

Referencing PR #26675, I have added validation for the `algorithmName` parameter in `SageMakerCreateTrainingJob`.
However, it was suggested that changes for validation should be separated.  So, I have created this PR.

Docs for `algorithmName`:
https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_AlgorithmSpecification.html#API_AlgorithmSpecification_Contents


Exemption Request:
This change does not alter the behavior.
I believe the unit test `create-training-job.test.ts` that I have added is sufficient to test this change.

----

*By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license*
  • Loading branch information
tmyoda authored Aug 29, 2023
1 parent 4fd510e commit 1cead3b
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { renderEnvironment, renderTags } from './private/utils';
import * as ec2 from '../../../aws-ec2';
import * as iam from '../../../aws-iam';
import * as sfn from '../../../aws-stepfunctions';
import { Duration, Lazy, Size, Stack } from '../../../core';
import { Duration, Lazy, Size, Stack, Token } from '../../../core';
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';

/**
Expand Down Expand Up @@ -163,6 +163,14 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
throw new Error('Must define either an algorithm name or training image URI in the algorithm specification');
}

// check that both algorithm name and image are not defined
if (props.algorithmSpecification.algorithmName && props.algorithmSpecification.trainingImage) {
throw new Error('Cannot define both an algorithm name and training image URI in the algorithm specification');
}

// validate algorithm name
this.validateAlgorithmName(props.algorithmSpecification.algorithmName);

// set the input mode to 'File' if not defined
this.algorithmSpecification = props.algorithmSpecification.trainingInputMode
? props.algorithmSpecification
Expand Down Expand Up @@ -324,6 +332,21 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
: {};
}

private validateAlgorithmName(algorithmName?: string): void {
if (algorithmName === undefined || Token.isUnresolved(algorithmName)) {
return;
}

if (algorithmName.length < 1 || 170 < algorithmName.length) {
throw new Error(`Algorithm name length must be between 1 and 170, but got ${algorithmName.length}`);
}

const regex = /^(arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:[a-z\-]*\/)?([a-zA-Z0-9]([a-zA-Z0-9-]){0,62})(?<!-)$/;
if (!regex.test(algorithmName)) {
throw new Error(`Expected algorithm name to match pattern ${regex.source}, but got ${algorithmName}`);
}
}

private makePolicyStatements(): iam.PolicyStatement[] {
// set the sagemaker role or create new one
this._grantPrincipal = this._role =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,146 @@ test('Cannot create a SageMaker train task with both algorithm name and image na
}))
.toThrowError(/Must define either an algorithm name or training image URI in the algorithm specification/);
});

test('Cannot create a SageMaker train task with both algorithm name and image name defined', () => {

expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'BlazingText',
trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')),
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
}))
.toThrowError(/Cannot define both an algorithm name and training image URI in the algorithm specification/);
});

test('create a SageMaker train task with trainingImage', () => {

const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')),
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
});

// THEN
expect(stack.resolve(task.toStateJson())).toMatchObject({
Parameters: {
AlgorithmSpecification: {
'TrainingImage.$': '$.Training.imageName',
'TrainingInputMode': 'File',
},
},
});
});

test('create a SageMaker train task with image URI algorithmName', () => {

const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
});

// THEN
expect(stack.resolve(task.toStateJson())).toMatchObject({
Parameters: {
AlgorithmSpecification: {
AlgorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
},
},
});
});

test('Cannot create a SageMaker train task when algorithmName length is 171 or more', () => {

expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'a'.repeat(171), // maximum length is 170
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
}))
.toThrowError(/Algorithm name length must be between 1 and 170, but got 171/);
});

test('Cannot create a SageMaker train task with incorrect algorithmName', () => {

expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'Blazing_Text', // underscores are not allowed
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
}))
.toThrowError(/Expected algorithm name to match pattern/);
});

0 comments on commit 1cead3b

Please sign in to comment.