Skip to content

Commit

Permalink
add validation for algorithm name
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Aug 24, 2023
1 parent b1f4e27 commit 4ba442d
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,14 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
throw new Error('Must define either an algorithm name or training image URI in the algorithm specification');
}

// check that both algorithm name and image are not defined
if (props.algorithmSpecification.algorithmName && props.algorithmSpecification.trainingImage) {
throw new Error('Cannot define both an algorithm name and training image URI in the algorithm specification');
}

// validate algorithm name
this.validateAlgorithmName(props.algorithmSpecification.algorithmName);

// set the input mode to 'File' if not defined
this.algorithmSpecification = props.algorithmSpecification.trainingInputMode
? props.algorithmSpecification
Expand Down Expand Up @@ -324,6 +332,21 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam
: {};
}

private validateAlgorithmName(algorithmName?: string): void {
if (algorithmName === undefined) {
return;
}

if (algorithmName.length < 1 || 170 < algorithmName.length) {
throw new Error(`Algorithm name length must be between 1 and 170, but got ${algorithmName.length}`);
}

const regex = /^(arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:[a-z\-]*\/)?([a-zA-Z0-9]([a-zA-Z0-9-]){0,62})(?<!-)$/;
if (!regex.test(algorithmName)) {
throw new Error(`Expected algorithm name to match pattern ${regex.source}, but got ${algorithmName}`);
}
}

private makePolicyStatements(): iam.PolicyStatement[] {
// set the sagemaker role or create new one
this._grantPrincipal = this._role =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,146 @@ test('Cannot create a SageMaker train task with both algorithm name and image na
}))
.toThrowError(/Must define either an algorithm name or training image URI in the algorithm specification/);
});

test('Cannot create a SageMaker train task with both algorithm name and image name defined', () => {

expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'BlazingText',
trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')),
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
}))
.toThrowError(/Cannot define both an algorithm name and training image URI in the algorithm specification/);
});

test('create a SageMaker train task with trainingImage', () => {

const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
trainingImage: tasks.DockerImage.fromJsonExpression(sfn.JsonPath.stringAt('$.Training.imageName')),
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
});

// THEN
expect(stack.resolve(task.toStateJson())).toMatchObject({
Parameters: {
AlgorithmSpecification: {
'TrainingImage.$': '$.Training.imageName',
'TrainingInputMode': 'File',
},
},
});
});

test('create a SageMaker train task with image URI algorithmName', () => {

const task = new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
});

// THEN
expect(stack.resolve(task.toStateJson())).toMatchObject({
Parameters: {
AlgorithmSpecification: {
AlgorithmName: 'arn:aws:sagemaker:us-east-1:123456789012:algorithm/scikit-decision-trees',
},
},
});
});

test('Cannot create a SageMaker train task when algorithmName length is 171 or more', () => {

expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'a'.repeat(171), // maximum length is 170
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
}))
.toThrowError(/Algorithm name length must be between 1 and 170, but got 171/);
});

test('Cannot create a SageMaker train task with incorrect algorithmName', () => {

expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', {
trainingJobName: 'myTrainJob',
algorithmSpecification: {
algorithmName: 'Blazing_Text', // underscores are not allowed
},
inputDataConfig: [
{
channelName: 'train',
dataSource: {
s3DataSource: {
s3DataType: tasks.S3DataType.S3_PREFIX,
s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'),
},
},
},
],
outputDataConfig: {
s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'),
},
}))
.toThrowError(/Expected algorithm name to match pattern/);
});

0 comments on commit 4ba442d

Please sign in to comment.