diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/README.md b/packages/@aws-cdk/aws-stepfunctions-tasks/README.md index c8482f9e57f09..e0e89b4ecd924 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/README.md +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/README.md @@ -617,37 +617,33 @@ Step Functions supports [AWS SageMaker](https://docs.aws.amazon.com/step-functio You can call the [`CreateTrainingJob`](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html) API from a `Task` state. ```ts -new sfn.Task(stack, 'TrainSagemaker', { - task: new tasks.SagemakerTrainTask({ - trainingJobName: sfn.Data.stringAt('$.JobName'), - role, - algorithmSpecification: { - algorithmName: 'BlazingText', - trainingInputMode: tasks.InputMode.FILE, - }, - inputDataConfig: [ - { - channelName: 'train', - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3_PREFIX, - s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), - }, - }, +new sfn.SagemakerTrainTask(this, 'TrainSagemaker', { + trainingJobName: sfn.Data.stringAt('$.JobName'), + role, + algorithmSpecification: { + algorithmName: 'BlazingText', + trainingInputMode: tasks.InputMode.FILE, + }, + inputDataConfig: [{ + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), }, - ], - outputDataConfig: { - s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath'), - }, - resourceConfig: { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeSizeInGB: 50, }, - stoppingCondition: { - maxRuntime: cdk.Duration.hours(1), - }, - }), + }], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath'), + }, + resourceConfig: { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), + volumeSize: cdk.Size.gibibytes(50), + }, + stoppingCondition: { + maxRuntime: cdk.Duration.hours(1), + }, }); ``` @@ -656,29 +652,27 @@ new sfn.Task(stack, 'TrainSagemaker', { You can call the [`CreateTransformJob`](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html) API from a `Task` state. ```ts -const transformJob = new tasks.SagemakerTransformTask( - transformJobName: "MyTransformJob", - modelName: "MyModelName", - role, - transformInput: { - transformDataSource: { - s3DataSource: { - s3Uri: 's3://inputbucket/train', - s3DataType: S3DataType.S3Prefix, - } - } - }, - transformOutput: { - s3OutputPath: 's3://outputbucket/TransformJobOutputPath', - }, - transformResources: { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), +new sfn.SagemakerTransformTask(this, 'Batch Inference', { + transformJobName: 'MyTransformJob', + modelName: 'MyModelName', + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/train', + s3DataType: S3DataType.S3Prefix, + } + } + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/TransformJobOutputPath', + }, + transformResources: { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), + } }); -const task = new sfn.Task(this, 'Batch Inference', { - task: transformJob -}); ``` ## SNS diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts index b9a5cd0a9f062..4dad4bf2c295c 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts @@ -10,9 +10,9 @@ export * from './sqs/send-to-queue'; export * from './sqs/send-message'; export * from './ecs/run-ecs-ec2-task'; export * from './ecs/run-ecs-fargate-task'; -export * from './sagemaker/sagemaker-task-base-types'; -export * from './sagemaker/sagemaker-train-task'; -export * from './sagemaker/sagemaker-transform-task'; +export * from './sagemaker/base-types'; +export * from './sagemaker/create-training-job'; +export * from './sagemaker/create-transform-job'; export * from './start-execution'; export * from './stepfunctions/start-execution'; export * from './evaluate-expression'; diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts similarity index 98% rename from packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts rename to packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts index 1db442e348f75..6f1c5f03dcc37 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts @@ -5,13 +5,13 @@ import * as iam from '@aws-cdk/aws-iam'; import * as kms from '@aws-cdk/aws-kms'; import * as s3 from '@aws-cdk/aws-s3'; import * as sfn from '@aws-cdk/aws-stepfunctions'; -import { Construct, Duration } from '@aws-cdk/core'; +import { Construct, Duration, Size } from '@aws-cdk/core'; /** * Task to train a machine learning model using Amazon SageMaker * @experimental */ -export interface ISageMakerTask extends sfn.IStepFunctionsTask, iam.IGrantable {} +export interface ISageMakerTask extends iam.IGrantable {} /** * Specify the training algorithm and algorithm-specific metadata @@ -230,7 +230,7 @@ export interface ResourceConfig { * * @default 10 GB EBS volume. */ - readonly volumeSizeInGB: number; + readonly volumeSize: Size; } /** @@ -622,7 +622,7 @@ export interface TransformResources { * * @default - None */ - readonly volumeKmsKeyId?: kms.Key; + readonly volumeEncryptionKey?: kms.IKey; } /** diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-train-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts similarity index 52% rename from packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-train-task.ts rename to packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts index 758e8a065dc8c..f541a0e692a4f 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-train-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -1,18 +1,16 @@ import * as ec2 from '@aws-cdk/aws-ec2'; import * as iam from '@aws-cdk/aws-iam'; import * as sfn from '@aws-cdk/aws-stepfunctions'; -import { Duration, Lazy, Stack } from '@aws-cdk/core'; -import { getResourceArn } from '../resource-arn-suffix'; -import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, - S3DataType, StoppingCondition, VpcConfig } from './sagemaker-task-base-types'; +import { Construct, Duration, Lazy, Size, Stack } from '@aws-cdk/core'; +import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; +import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, S3DataType, StoppingCondition, VpcConfig } from './base-types'; /** * Properties for creating an Amazon SageMaker training job * * @experimental */ -export interface SagemakerTrainTaskProps { - +export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps { /** * Training Job Name. */ @@ -24,19 +22,10 @@ export interface SagemakerTrainTaskProps { * * See https://docs.aws.amazon.com/fr_fr/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-createtrainingjob-perms * - * @default - a role with appropriate permissions will be created. + * @default - a role will be created. */ readonly role?: iam.IRole; - /** - * The service integration pattern indicates different ways to call SageMaker APIs. - * - * The valid value is either FIRE_AND_FORGET or SYNC. - * - * @default FIRE_AND_FORGET - */ - readonly integrationPattern?: sfn.ServiceIntegrationPattern; - /** * Identifies the training algorithm to use. */ @@ -49,7 +38,7 @@ export interface SagemakerTrainTaskProps { * * @default - No hyperparameters */ - readonly hyperparameters?: {[key: string]: any}; + readonly hyperparameters?: { [key: string]: any }; /** * Describes the various datasets (e.g. train, validation, test) and the Amazon S3 location where stored. @@ -61,7 +50,7 @@ export interface SagemakerTrainTaskProps { * * @default - No tags */ - readonly tags?: {[key: string]: string}; + readonly tags?: { [key: string]: string }; /** * Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training. @@ -95,13 +84,20 @@ export interface SagemakerTrainTaskProps { * * @experimental */ -export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn.IStepFunctionsTask { +export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam.IGrantable, ec2.IConnectable { + private static readonly SUPPORTED_INTEGRATION_PATTERNS: sfn.IntegrationPattern[] = [ + sfn.IntegrationPattern.REQUEST_RESPONSE, + sfn.IntegrationPattern.RUN_JOB, + ]; /** * Allows specify security group connections for instances of this fleet. */ public readonly connections: ec2.Connections = new ec2.Connections(); + protected readonly taskPolicies?: iam.PolicyStatement[]; + protected readonly taskMetrics?: sfn.TaskMetricsConfig; + /** * The Algorithm Specification */ @@ -126,27 +122,21 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn private securityGroup?: ec2.ISecurityGroup; private readonly securityGroups: ec2.ISecurityGroup[] = []; private readonly subnets?: string[]; - private readonly integrationPattern: sfn.ServiceIntegrationPattern; + private readonly integrationPattern: sfn.IntegrationPattern; private _role?: iam.IRole; private _grantPrincipal?: iam.IPrincipal; - constructor(private readonly props: SagemakerTrainTaskProps) { - this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET; + constructor(scope: Construct, id: string, private readonly props: SageMakerCreateTrainingJobProps) { + super(scope, id, props); - const supportedPatterns = [ - sfn.ServiceIntegrationPattern.FIRE_AND_FORGET, - sfn.ServiceIntegrationPattern.SYNC, - ]; - - if (!supportedPatterns.includes(this.integrationPattern)) { - throw new Error(`Invalid Service Integration Pattern: ${this.integrationPattern} is not supported to call SageMaker.`); - } + this.integrationPattern = props.integrationPattern || sfn.IntegrationPattern.REQUEST_RESPONSE; + validatePatternSupported(this.integrationPattern, SageMakerCreateTrainingJob.SUPPORTED_INTEGRATION_PATTERNS); // set the default resource config if not defined. this.resourceConfig = props.resourceConfig || { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE), - volumeSizeInGB: 10, + volumeSize: Size.gibibytes(10), }; // set the stopping condition if not defined @@ -155,20 +145,22 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn }; // check that either algorithm name or image is defined - if ((!props.algorithmSpecification.algorithmName) && (!props.algorithmSpecification.trainingImage)) { + if (!props.algorithmSpecification.algorithmName && !props.algorithmSpecification.trainingImage) { throw new Error('Must define either an algorithm name or training image URI in the algorithm specification'); } // set the input mode to 'File' if not defined - this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ? - ( props.algorithmSpecification ) : - ( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } ); + this.algorithmSpecification = props.algorithmSpecification.trainingInputMode + ? props.algorithmSpecification + : { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE }; // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined - this.inputDataConfig = props.inputDataConfig.map(config => { + this.inputDataConfig = props.inputDataConfig.map((config) => { if (!config.dataSource.s3DataSource.s3DataType) { - return Object.assign({}, config, { dataSource: { s3DataSource: - { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } }); + return { + ...config, + dataSource: { s3DataSource: { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } }, + }; } else { return config; } @@ -177,9 +169,10 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn // add the security groups to the connections object if (props.vpcConfig) { this.vpc = props.vpcConfig.vpc; - this.subnets = (props.vpcConfig.subnets) ? - (this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds) : this.vpc.selectSubnets().subnetIds; + this.subnets = props.vpcConfig.subnets ? this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds : this.vpc.selectSubnets().subnetIds; } + + this.taskPolicies = this.makePolicyStatements(); } /** @@ -211,137 +204,84 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn this.securityGroups.push(securityGroup); } - public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { - // set the sagemaker role or create new one - this._grantPrincipal = this._role = this.props.role || new iam.Role(task, 'SagemakerRole', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - inlinePolicies: { - CreateTrainingJob: new iam.PolicyDocument({ - statements: [ - new iam.PolicyStatement({ - actions: [ - 'cloudwatch:PutMetricData', - 'logs:CreateLogStream', - 'logs:PutLogEvents', - 'logs:CreateLogGroup', - 'logs:DescribeLogStreams', - 'ecr:GetAuthorizationToken', - ...this.props.vpcConfig - ? [ - 'ec2:CreateNetworkInterface', - 'ec2:CreateNetworkInterfacePermission', - 'ec2:DeleteNetworkInterface', - 'ec2:DeleteNetworkInterfacePermission', - 'ec2:DescribeNetworkInterfaces', - 'ec2:DescribeVpcs', - 'ec2:DescribeDhcpOptions', - 'ec2:DescribeSubnets', - 'ec2:DescribeSecurityGroups', - ] - : [], - ], - resources: ['*'], // Those permissions cannot be resource-scoped - }), - ], - }), - }, - }); - - if (this.props.outputDataConfig.encryptionKey) { - this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role); - } - - if (this.props.resourceConfig && this.props.resourceConfig.volumeEncryptionKey) { - this.props.resourceConfig.volumeEncryptionKey.grant(this._role, 'kms:CreateGrant'); - } - - // create a security group if not defined - if (this.vpc && this.securityGroup === undefined) { - this.securityGroup = new ec2.SecurityGroup(task, 'TrainJobSecurityGroup', { - vpc: this.vpc, - }); - this.connections.addSecurityGroup(this.securityGroup); - this.securityGroups.push(this.securityGroup); - } - + protected renderTask(): any { return { - resourceArn: getResourceArn('sagemaker', 'createTrainingJob', this.integrationPattern), - parameters: this.renderParameters(), - policyStatements: this.makePolicyStatements(task), + Resource: integrationResourceArn('sagemaker', 'createTrainingJob', this.integrationPattern), + Parameters: sfn.FieldUtils.renderObject(this.renderParameters()), }; } - private renderParameters(): {[key: string]: any} { + private renderParameters(): { [key: string]: any } { return { TrainingJobName: this.props.trainingJobName, RoleArn: this._role!.roleArn, - ...(this.renderAlgorithmSpecification(this.algorithmSpecification)), - ...(this.renderInputDataConfig(this.inputDataConfig)), - ...(this.renderOutputDataConfig(this.props.outputDataConfig)), - ...(this.renderResourceConfig(this.resourceConfig)), - ...(this.renderStoppingCondition(this.stoppingCondition)), - ...(this.renderHyperparameters(this.props.hyperparameters)), - ...(this.renderTags(this.props.tags)), - ...(this.renderVpcConfig(this.props.vpcConfig)), + ...this.renderAlgorithmSpecification(this.algorithmSpecification), + ...this.renderInputDataConfig(this.inputDataConfig), + ...this.renderOutputDataConfig(this.props.outputDataConfig), + ...this.renderResourceConfig(this.resourceConfig), + ...this.renderStoppingCondition(this.stoppingCondition), + ...this.renderHyperparameters(this.props.hyperparameters), + ...this.renderTags(this.props.tags), + ...this.renderVpcConfig(this.props.vpcConfig), }; } - private renderAlgorithmSpecification(spec: AlgorithmSpecification): {[key: string]: any} { + private renderAlgorithmSpecification(spec: AlgorithmSpecification): { [key: string]: any } { return { AlgorithmSpecification: { TrainingInputMode: spec.trainingInputMode, - ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage.bind(this).imageUri } : {}, - ...(spec.algorithmName) ? { AlgorithmName: spec.algorithmName } : {}, - ...(spec.metricDefinitions) ? - { MetricDefinitions: spec.metricDefinitions - .map(metric => ({ Name: metric.name, Regex: metric.regex })) } : {}, + ...(spec.trainingImage ? { TrainingImage: spec.trainingImage.bind(this).imageUri } : {}), + ...(spec.algorithmName ? { AlgorithmName: spec.algorithmName } : {}), + ...(spec.metricDefinitions + ? { MetricDefinitions: spec.metricDefinitions.map((metric) => ({ Name: metric.name, Regex: metric.regex })) } + : {}), }, }; } - private renderInputDataConfig(config: Channel[]): {[key: string]: any} { + private renderInputDataConfig(config: Channel[]): { [key: string]: any } { return { - InputDataConfig: config.map(channel => ({ + InputDataConfig: config.map((channel) => ({ ChannelName: channel.channelName, DataSource: { S3DataSource: { S3Uri: channel.dataSource.s3DataSource.s3Location.bind(this, { forReading: true }).uri, S3DataType: channel.dataSource.s3DataSource.s3DataType, - ...(channel.dataSource.s3DataSource.s3DataDistributionType) ? - { S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType} : {}, - ...(channel.dataSource.s3DataSource.attributeNames) ? - { AtttributeNames: channel.dataSource.s3DataSource.attributeNames } : {}, + ...(channel.dataSource.s3DataSource.s3DataDistributionType + ? { S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType } + : {}), + ...(channel.dataSource.s3DataSource.attributeNames ? { AtttributeNames: channel.dataSource.s3DataSource.attributeNames } : {}), }, }, - ...(channel.compressionType) ? { CompressionType: channel.compressionType } : {}, - ...(channel.contentType) ? { ContentType: channel.contentType } : {}, - ...(channel.inputMode) ? { InputMode: channel.inputMode } : {}, - ...(channel.recordWrapperType) ? { RecordWrapperType: channel.recordWrapperType } : {}, + ...(channel.compressionType ? { CompressionType: channel.compressionType } : {}), + ...(channel.contentType ? { ContentType: channel.contentType } : {}), + ...(channel.inputMode ? { InputMode: channel.inputMode } : {}), + ...(channel.recordWrapperType ? { RecordWrapperType: channel.recordWrapperType } : {}), })), }; } - private renderOutputDataConfig(config: OutputDataConfig): {[key: string]: any} { + private renderOutputDataConfig(config: OutputDataConfig): { [key: string]: any } { return { OutputDataConfig: { S3OutputPath: config.s3OutputLocation.bind(this, { forWriting: true }).uri, - ...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {}, + ...(config.encryptionKey ? { KmsKeyId: config.encryptionKey.keyArn } : {}), }, }; } - private renderResourceConfig(config: ResourceConfig): {[key: string]: any} { + private renderResourceConfig(config: ResourceConfig): { [key: string]: any } { return { ResourceConfig: { InstanceCount: config.instanceCount, InstanceType: 'ml.' + config.instanceType, - VolumeSizeInGB: config.volumeSizeInGB, - ...(config.volumeEncryptionKey) ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {}, + VolumeSizeInGB: config.volumeSize.toGibibytes(), + ...(config.volumeEncryptionKey ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {}), }, }; } - private renderStoppingCondition(config: StoppingCondition): {[key: string]: any} { + private renderStoppingCondition(config: StoppingCondition): { [key: string]: any } { return { StoppingCondition: { MaxRuntimeInSeconds: config.maxRuntime && config.maxRuntime.toSeconds(), @@ -349,23 +289,81 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn }; } - private renderHyperparameters(params: {[key: string]: any} | undefined): {[key: string]: any} { - return (params) ? { HyperParameters: params } : {}; + private renderHyperparameters(params: { [key: string]: any } | undefined): { [key: string]: any } { + return params ? { HyperParameters: params } : {}; } - private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { - return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; + private renderTags(tags: { [key: string]: any } | undefined): { [key: string]: any } { + return tags ? { Tags: Object.keys(tags).map((key) => ({ Key: key, Value: tags[key] })) } : {}; } - private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} { - return (config) ? { VpcConfig: { - SecurityGroupIds: Lazy.listValue({ produce: () => (this.securityGroups.map(sg => (sg.securityGroupId))) }), - Subnets: this.subnets, - }} : {}; + private renderVpcConfig(config: VpcConfig | undefined): { [key: string]: any } { + return config + ? { + VpcConfig: { + SecurityGroupIds: Lazy.listValue({ produce: () => this.securityGroups.map((sg) => sg.securityGroupId) }), + Subnets: this.subnets, + }, + } + : {}; } - private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { - const stack = Stack.of(task); + private makePolicyStatements(): iam.PolicyStatement[] { + // set the sagemaker role or create new one + this._grantPrincipal = this._role = + this.props.role || + new iam.Role(this, 'SagemakerRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + inlinePolicies: { + CreateTrainingJob: new iam.PolicyDocument({ + statements: [ + new iam.PolicyStatement({ + actions: [ + 'cloudwatch:PutMetricData', + 'logs:CreateLogStream', + 'logs:PutLogEvents', + 'logs:CreateLogGroup', + 'logs:DescribeLogStreams', + 'ecr:GetAuthorizationToken', + ...(this.props.vpcConfig + ? [ + 'ec2:CreateNetworkInterface', + 'ec2:CreateNetworkInterfacePermission', + 'ec2:DeleteNetworkInterface', + 'ec2:DeleteNetworkInterfacePermission', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DescribeVpcs', + 'ec2:DescribeDhcpOptions', + 'ec2:DescribeSubnets', + 'ec2:DescribeSecurityGroups', + ] + : []), + ], + resources: ['*'], // Those permissions cannot be resource-scoped + }), + ], + }), + }, + }); + + if (this.props.outputDataConfig.encryptionKey) { + this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role); + } + + if (this.props.resourceConfig && this.props.resourceConfig.volumeEncryptionKey) { + this.props.resourceConfig.volumeEncryptionKey.grant(this._role, 'kms:CreateGrant'); + } + + // create a security group if not defined + if (this.vpc && this.securityGroup === undefined) { + this.securityGroup = new ec2.SecurityGroup(this, 'TrainJobSecurityGroup', { + vpc: this.vpc, + }); + this.connections.addSecurityGroup(this.securityGroup); + this.securityGroups.push(this.securityGroup); + } + + const stack = Stack.of(this); // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html const policyStatements = [ @@ -393,15 +391,19 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn }), ]; - if (this.integrationPattern === sfn.ServiceIntegrationPattern.SYNC) { - policyStatements.push(new iam.PolicyStatement({ - actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], - resources: [stack.formatArn({ - service: 'events', - resource: 'rule', - resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule', - })], - })); + if (this.integrationPattern === sfn.IntegrationPattern.RUN_JOB) { + policyStatements.push( + new iam.PolicyStatement({ + actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], + resources: [ + stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule', + }), + ], + }), + ); } return policyStatements; diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-transform-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts similarity index 51% rename from packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-transform-task.ts rename to packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts index 5d4449d052a17..111a15500443e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-transform-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts @@ -1,17 +1,16 @@ import * as ec2 from '@aws-cdk/aws-ec2'; import * as iam from '@aws-cdk/aws-iam'; import * as sfn from '@aws-cdk/aws-stepfunctions'; -import { Stack } from '@aws-cdk/core'; -import { getResourceArn } from '../resource-arn-suffix'; -import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; +import { Construct, Size, Stack } from '@aws-cdk/core'; +import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; +import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './base-types'; /** * Properties for creating an Amazon SageMaker training job task * * @experimental */ -export interface SagemakerTransformProps { - +export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps { /** * Training Job Name. */ @@ -24,15 +23,6 @@ export interface SagemakerTransformProps { */ readonly role?: iam.IRole; - /** - * The service integration pattern indicates different ways to call SageMaker APIs. - * - * The valid value is either FIRE_AND_FORGET or SYNC. - * - * @default FIRE_AND_FORGET - */ - readonly integrationPattern?: sfn.ServiceIntegrationPattern; - /** * Number of records to include in a mini-batch for an HTTP inference request. * @@ -45,7 +35,7 @@ export interface SagemakerTransformProps { * * @default - No environment variables */ - readonly environment?: {[key: string]: string}; + readonly environment?: { [key: string]: string }; /** * Maximum number of parallel requests that can be sent to each instance in a transform job. @@ -60,7 +50,7 @@ export interface SagemakerTransformProps { * * @default 6 */ - readonly maxPayloadInMB?: number; + readonly maxPayload?: Size; /** * Name of the model that you want to use for the transform job. @@ -72,7 +62,7 @@ export interface SagemakerTransformProps { * * @default - No tags */ - readonly tags?: {[key: string]: string}; + readonly tags?: { [key: string]: string }; /** * Dataset to be transformed and the Amazon S3 location where it is stored. @@ -97,7 +87,14 @@ export interface SagemakerTransformProps { * * @experimental */ -export class SagemakerTransformTask implements sfn.IStepFunctionsTask { +export class SageMakerCreateTransformJob extends sfn.TaskStateBase { + private static readonly SUPPORTED_INTEGRATION_PATTERNS: sfn.IntegrationPattern[] = [ + sfn.IntegrationPattern.REQUEST_RESPONSE, + sfn.IntegrationPattern.RUN_JOB, + ]; + + protected readonly taskPolicies?: iam.PolicyStatement[]; + protected readonly taskMetrics?: sfn.TaskMetricsConfig; /** * Dataset to be transformed and the Amazon S3 location where it is stored. @@ -108,20 +105,13 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { * ML compute instances for the transform job. */ private readonly transformResources: TransformResources; - private readonly integrationPattern: sfn.ServiceIntegrationPattern; + private readonly integrationPattern: sfn.IntegrationPattern; private _role?: iam.IRole; - constructor(private readonly props: SagemakerTransformProps) { - this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET; - - const supportedPatterns = [ - sfn.ServiceIntegrationPattern.FIRE_AND_FORGET, - sfn.ServiceIntegrationPattern.SYNC, - ]; - - if (!supportedPatterns.includes(this.integrationPattern)) { - throw new Error(`Invalid Service Integration Pattern: ${this.integrationPattern} is not supported to call SageMaker.`); - } + constructor(scope: Construct, id: string, private readonly props: SageMakerCreateTransformJobProps) { + super(scope, id, props); + this.integrationPattern = props.integrationPattern || sfn.IntegrationPattern.REQUEST_RESPONSE; + validatePatternSupported(this.integrationPattern, SageMakerCreateTransformJob.SUPPORTED_INTEGRATION_PATTERNS); // set the sagemaker role or create new one if (props.role) { @@ -129,38 +119,25 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { } // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined - this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? (props.transformInput) : - Object.assign({}, props.transformInput, - { transformDataSource: - { s3DataSource: - { ...props.transformInput.transformDataSource.s3DataSource, - s3DataType: S3DataType.S3_PREFIX, - }, - }, - }); + this.transformInput = props.transformInput.transformDataSource.s3DataSource.s3DataType + ? props.transformInput + : Object.assign({}, props.transformInput, { + transformDataSource: { s3DataSource: { ...props.transformInput.transformDataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } }, + }); // set the default value for the transform resources this.transformResources = props.transformResources || { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE), }; - } - public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { - // create new role if doesn't exist - if (this._role === undefined) { - this._role = new iam.Role(task, 'SagemakerTransformRole', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - managedPolicies: [ - iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), - ], - }); - } + this.taskPolicies = this.makePolicyStatements(); + } + protected renderTask(): any { return { - resourceArn: getResourceArn('sagemaker', 'createTransformJob', this.integrationPattern), - parameters: this.renderParameters(), - policyStatements: this.makePolicyStatements(task), + Resource: integrationResourceArn('sagemaker', 'createTransformJob', this.integrationPattern), + Parameters: sfn.FieldUtils.renderObject(this.renderParameters()), }; } @@ -176,78 +153,88 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { return this._role; } - private renderParameters(): {[key: string]: any} { + private renderParameters(): { [key: string]: any } { return { - ...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {}, - ...(this.renderEnvironment(this.props.environment)), - ...(this.props.maxConcurrentTransforms) ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}, - ...(this.props.maxPayloadInMB) ? { MaxPayloadInMB: this.props.maxPayloadInMB } : {}, + ...(this.props.batchStrategy ? { BatchStrategy: this.props.batchStrategy } : {}), + ...this.renderEnvironment(this.props.environment), + ...(this.props.maxConcurrentTransforms ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}), + ...(this.props.maxPayload ? { MaxPayloadInMB: this.props.maxPayload.toMebibytes() } : {}), ModelName: this.props.modelName, - ...(this.renderTags(this.props.tags)), - ...(this.renderTransformInput(this.transformInput)), + ...this.renderTags(this.props.tags), + ...this.renderTransformInput(this.transformInput), TransformJobName: this.props.transformJobName, - ...(this.renderTransformOutput(this.props.transformOutput)), - ...(this.renderTransformResources(this.transformResources)), + ...this.renderTransformOutput(this.props.transformOutput), + ...this.renderTransformResources(this.transformResources), }; } - private renderTransformInput(input: TransformInput): {[key: string]: any} { + private renderTransformInput(input: TransformInput): { [key: string]: any } { return { TransformInput: { - ...(input.compressionType) ? { CompressionType: input.compressionType } : {}, - ...(input.contentType) ? { ContentType: input.contentType } : {}, + ...(input.compressionType ? { CompressionType: input.compressionType } : {}), + ...(input.contentType ? { ContentType: input.contentType } : {}), DataSource: { S3DataSource: { S3Uri: input.transformDataSource.s3DataSource.s3Uri, S3DataType: input.transformDataSource.s3DataSource.s3DataType, }, }, - ...(input.splitType) ? { SplitType: input.splitType } : {}, + ...(input.splitType ? { SplitType: input.splitType } : {}), }, }; } - private renderTransformOutput(output: TransformOutput): {[key: string]: any} { + private renderTransformOutput(output: TransformOutput): { [key: string]: any } { return { TransformOutput: { S3OutputPath: output.s3OutputPath, - ...(output.encryptionKey) ? { KmsKeyId: output.encryptionKey.keyArn } : {}, - ...(output.accept) ? { Accept: output.accept } : {}, - ...(output.assembleWith) ? { AssembleWith: output.assembleWith } : {}, + ...(output.encryptionKey ? { KmsKeyId: output.encryptionKey.keyArn } : {}), + ...(output.accept ? { Accept: output.accept } : {}), + ...(output.assembleWith ? { AssembleWith: output.assembleWith } : {}), }, }; } - private renderTransformResources(resources: TransformResources): {[key: string]: any} { + private renderTransformResources(resources: TransformResources): { [key: string]: any } { return { TransformResources: { InstanceCount: resources.instanceCount, InstanceType: 'ml.' + resources.instanceType, - ...(resources.volumeKmsKeyId) ? { VolumeKmsKeyId: resources.volumeKmsKeyId.keyArn } : {}, + ...(resources.volumeEncryptionKey ? { VolumeKmsKeyId: resources.volumeEncryptionKey.keyArn } : {}), }, }; } - private renderEnvironment(environment: {[key: string]: any} | undefined): {[key: string]: any} { - return (environment) ? { Environment: environment } : {}; + private renderEnvironment(environment: { [key: string]: any } | undefined): { [key: string]: any } { + return environment ? { Environment: environment } : {}; } - private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { - return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; + private renderTags(tags: { [key: string]: any } | undefined): { [key: string]: any } { + return tags ? { Tags: Object.keys(tags).map((key) => ({ Key: key, Value: tags[key] })) } : {}; } - private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { - const stack = Stack.of(task); + private makePolicyStatements(): iam.PolicyStatement[] { + const stack = Stack.of(this); + + // create new role if doesn't exist + if (this._role === undefined) { + this._role = new iam.Role(this, 'SagemakerTransformRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicies: [iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')], + }); + } // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html const policyStatements = [ new iam.PolicyStatement({ actions: ['sagemaker:CreateTransformJob', 'sagemaker:DescribeTransformJob', 'sagemaker:StopTransformJob'], - resources: [stack.formatArn({ - service: 'sagemaker', - resource: 'transform-job', - resourceName: '*', - })], + resources: [ + stack.formatArn({ + service: 'sagemaker', + resource: 'transform-job', + resourceName: '*', + }), + ], }), new iam.PolicyStatement({ actions: ['sagemaker:ListTags'], @@ -262,15 +249,19 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { }), ]; - if (this.integrationPattern === sfn.ServiceIntegrationPattern.SYNC) { - policyStatements.push(new iam.PolicyStatement({ - actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], - resources: [stack.formatArn({ - service: 'events', - resource: 'rule', - resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule', - }) ], - })); + if (this.integrationPattern === sfn.IntegrationPattern.RUN_JOB) { + policyStatements.push( + new iam.PolicyStatement({ + actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], + resources: [ + stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule', + }), + ], + }), + ); } return policyStatements; diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts similarity index 92% rename from packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-training-job.test.ts rename to packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts index 58b7d314b535d..4f02f9ac048a1 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-training-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts @@ -6,6 +6,7 @@ import * as s3 from '@aws-cdk/aws-s3'; import * as sfn from '@aws-cdk/aws-stepfunctions'; import * as cdk from '@aws-cdk/core'; import * as tasks from '../../lib'; +import { SageMakerCreateTrainingJob } from '../../lib/sagemaker/create-training-job'; let stack: cdk.Stack; @@ -16,7 +17,7 @@ beforeEach(() => { test('create basic training job', () => { // WHEN - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ + const task = new SageMakerCreateTrainingJob(stack, 'TrainSagemaker', { trainingJobName: 'MyTrainJob', algorithmSpecification: { algorithmName: 'BlazingText', @@ -34,7 +35,7 @@ test('create basic training job', () => { outputDataConfig: { s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath'), }, - })}); + }); // THEN expect(stack.resolve(task.toStateJson())).toEqual({ @@ -91,8 +92,8 @@ test('create basic training job', () => { test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration pattern', () => { expect(() => { - new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ - integrationPattern: sfn.ServiceIntegrationPattern.WAIT_FOR_TASK_TOKEN, + new SageMakerCreateTrainingJob(stack, 'TrainSagemaker', { + integrationPattern: sfn.IntegrationPattern.WAIT_FOR_TASK_TOKEN, trainingJobName: 'MyTrainJob', algorithmSpecification: { algorithmName: 'BlazingText', @@ -110,8 +111,8 @@ test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration patt outputDataConfig: { s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath'), }, - })}); - }).toThrow(/Invalid Service Integration Pattern: WAIT_FOR_TASK_TOKEN is not supported to call SageMaker./i); + }); + }).toThrow(/Unsupported service integration pattern. Supported Patterns: REQUEST_RESPONSE,RUN_JOB. Received: WAIT_FOR_TASK_TOKEN/i); }); test('create complex training job', () => { @@ -128,9 +129,9 @@ test('create complex training job', () => { ], }); - const trainTask = new tasks.SagemakerTrainTask({ + const trainTask = new SageMakerCreateTrainingJob(stack, 'TrainSagemaker', { trainingJobName: 'MyTrainJob', - integrationPattern: sfn.ServiceIntegrationPattern.SYNC, + integrationPattern: sfn.IntegrationPattern.RUN_JOB, role, algorithmSpecification: { algorithmName: 'BlazingText', @@ -177,7 +178,7 @@ test('create complex training job', () => { resourceConfig: { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeSizeInGB: 50, + volumeSize: cdk.Size.gibibytes(50), volumeEncryptionKey: kmsKey, }, stoppingCondition: { @@ -191,10 +192,9 @@ test('create complex training job', () => { }, }); trainTask.addSecurityGroup(securityGroup); - const task = new sfn.Task(stack, 'TrainSagemaker', { task: trainTask }); // THEN - expect(stack.resolve(task.toStateJson())).toEqual({ + expect(stack.resolve(trainTask.toStateJson())).toEqual({ Type: 'Task', Resource: { 'Fn::Join': [ @@ -272,8 +272,8 @@ test('create complex training job', () => { ], VpcConfig: { SecurityGroupIds: [ - { 'Fn::GetAtt': [ 'SecurityGroupDD263621', 'GroupId' ] }, { 'Fn::GetAtt': [ 'TrainSagemakerTrainJobSecurityGroup7C858EB9', 'GroupId' ] }, + { 'Fn::GetAtt': [ 'SecurityGroupDD263621', 'GroupId' ] }, ], Subnets: [ { Ref: 'VPCPrivateSubnet1Subnet8BCA10E0' }, @@ -293,7 +293,7 @@ test('pass param to training job', () => { ], }); - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ + const task = new SageMakerCreateTrainingJob(stack, 'TrainSagemaker', { trainingJobName: sfn.Data.stringAt('$.JobName'), role, algorithmSpecification: { @@ -317,12 +317,12 @@ test('pass param to training job', () => { resourceConfig: { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeSizeInGB: 50, + volumeSize: cdk.Size.gibibytes(50), }, stoppingCondition: { maxRuntime: cdk.Duration.hours(1), }, - })}); + }); // THEN expect(stack.resolve(task.toStateJson())).toEqual({ @@ -377,7 +377,7 @@ test('pass param to training job', () => { test('Cannot create a SageMaker train task with both algorithm name and image name missing', () => { - expect(() => new tasks.SagemakerTrainTask({ + expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { trainingJobName: 'myTrainJob', algorithmSpecification: {}, inputDataConfig: [ diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts similarity index 88% rename from packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-transform-job.test.ts rename to packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts index c08a28bb0c973..c53233523cfa7 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-transform-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts @@ -5,6 +5,7 @@ import * as kms from '@aws-cdk/aws-kms'; import * as sfn from '@aws-cdk/aws-stepfunctions'; import * as cdk from '@aws-cdk/core'; import * as tasks from '../../lib'; +import { SageMakerCreateTransformJob } from '../../lib/sagemaker/create-transform-job'; let stack: cdk.Stack; let role: iam.Role; @@ -22,7 +23,7 @@ beforeEach(() => { test('create basic transform job', () => { // WHEN - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ + const task = new SageMakerCreateTransformJob(stack, 'TransformTask', { transformJobName: 'MyTransformJob', modelName: 'MyModelName', transformInput: { @@ -35,7 +36,7 @@ test('create basic transform job', () => { transformOutput: { s3OutputPath: 's3://outputbucket/prefix', }, - }) }); + }); // THEN expect(stack.resolve(task.toStateJson())).toEqual({ @@ -77,8 +78,8 @@ test('create basic transform job', () => { test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration pattern', () => { expect(() => { - new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ - integrationPattern: sfn.ServiceIntegrationPattern.WAIT_FOR_TASK_TOKEN, + new SageMakerCreateTransformJob(stack, 'TransformTask', { + integrationPattern: sfn.IntegrationPattern.WAIT_FOR_TASK_TOKEN, transformJobName: 'MyTransformJob', modelName: 'MyModelName', transformInput: { @@ -91,17 +92,17 @@ test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration patt transformOutput: { s3OutputPath: 's3://outputbucket/prefix', }, - }) }); - }).toThrow(/Invalid Service Integration Pattern: WAIT_FOR_TASK_TOKEN is not supported to call SageMaker./i); + }); + }).toThrow(/Unsupported service integration pattern. Supported Patterns: REQUEST_RESPONSE,RUN_JOB. Received: WAIT_FOR_TASK_TOKEN/); }); test('create complex transform job', () => { // WHEN const kmsKey = new kms.Key(stack, 'Key'); - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ + const task = new SageMakerCreateTransformJob(stack, 'TransformTask', { transformJobName: 'MyTransformJob', modelName: 'MyModelName', - integrationPattern: sfn.ServiceIntegrationPattern.SYNC, + integrationPattern: sfn.IntegrationPattern.RUN_JOB, role, transformInput: { transformDataSource: { @@ -118,7 +119,7 @@ test('create complex transform job', () => { transformResources: { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeKmsKeyId: kmsKey, + volumeEncryptionKey: kmsKey, }, tags: { Project: 'MyProject', @@ -128,8 +129,8 @@ test('create complex transform job', () => { SOMEVAR: 'myvalue', }, maxConcurrentTransforms: 3, - maxPayloadInMB: 100, - }) }); + maxPayload: cdk.Size.mebibytes(100), + }); // THEN expect(stack.resolve(task.toStateJson())).toEqual({ @@ -182,7 +183,7 @@ test('create complex transform job', () => { test('pass param to transform job', () => { // WHEN - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ + const task = new SageMakerCreateTransformJob(stack, 'TransformTask', { transformJobName: sfn.Data.stringAt('$.TransformJobName'), modelName: sfn.Data.stringAt('$.ModelName'), role, @@ -201,7 +202,7 @@ test('pass param to transform job', () => { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), }, - }) }); + }); // THEN expect(stack.resolve(task.toStateJson())).toEqual({ diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.expected.json b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json similarity index 94% rename from packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.expected.json rename to packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json index 52aeac4dc5de3..cf95e9f59a16e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.expected.json +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json @@ -304,7 +304,7 @@ { "Ref": "AWS::AccountId" }, - ":training-job/MyTrainingJob*" + ":training-job/mytrainingjob*" ] ] } @@ -343,18 +343,28 @@ "StateMachine2E01A3A5": { "Type": "AWS::StepFunctions::StateMachine", "Properties": { + "RoleArn": { + "Fn::GetAtt": [ + "StateMachineRoleB840431D", + "Arn" + ] + }, "DefinitionString": { "Fn::Join": [ "", [ - "{\"StartAt\":\"TrainTask\",\"States\":{\"TrainTask\":{\"End\":true,\"Parameters\":{\"TrainingJobName\":\"MyTrainingJob\",\"RoleArn\":\"", + "{\"StartAt\":\"TrainTask\",\"States\":{\"TrainTask\":{\"End\":true,\"Type\":\"Task\",\"Resource\":\"arn:", + { + "Ref": "AWS::Partition" + }, + ":states:::sagemaker:createTrainingJob\",\"Parameters\":{\"TrainingJobName\":\"mytrainingjob\",\"RoleArn\":\"", { "Fn::GetAtt": [ "TrainTaskSagemakerRole0A9B1CDD", "Arn" ] }, - "\",\"AlgorithmSpecification\":{\"TrainingInputMode\":\"File\",\"AlgorithmName\":\"GRADIENT_ASCENT\"},\"InputDataConfig\":[{\"ChannelName\":\"InputData\",\"DataSource\":{\"S3DataSource\":{\"S3Uri\":\"https://s3.", + "\",\"AlgorithmSpecification\":{\"TrainingInputMode\":\"File\",\"AlgorithmName\":\"arn:aws:sagemaker:us-east-1:865070037744:algorithm/scikit-decision-trees-15423055-57b73412d2e93e9239e4e16f83298b8f\"},\"InputDataConfig\":[{\"ChannelName\":\"InputData\",\"DataSource\":{\"S3DataSource\":{\"S3Uri\":\"https://s3.", { "Ref": "AWS::Region" }, @@ -378,19 +388,9 @@ { "Ref": "TrainingData3FDB6D34" }, - "/result/\"},\"ResourceConfig\":{\"InstanceCount\":1,\"InstanceType\":\"ml.m4.xlarge\",\"VolumeSizeInGB\":10},\"StoppingCondition\":{\"MaxRuntimeInSeconds\":3600}},\"Type\":\"Task\",\"Resource\":\"arn:", - { - "Ref": "AWS::Partition" - }, - ":states:::sagemaker:createTrainingJob\"}}}" + "/result/\"},\"ResourceConfig\":{\"InstanceCount\":1,\"InstanceType\":\"ml.m4.xlarge\",\"VolumeSizeInGB\":10},\"StoppingCondition\":{\"MaxRuntimeInSeconds\":3600}}}}}" ] ] - }, - "RoleArn": { - "Fn::GetAtt": [ - "StateMachineRoleB840431D", - "Arn" - ] } }, "DependsOn": [ @@ -398,5 +398,12 @@ "StateMachineRoleB840431D" ] } + }, + "Outputs": { + "stateMachineArn": { + "Value": { + "Ref": "StateMachine2E01A3A5" + } + } } } \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts new file mode 100644 index 0000000000000..28e4e65ff0e1e --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts @@ -0,0 +1,53 @@ +import { Key } from '@aws-cdk/aws-kms'; +import { Bucket, BucketEncryption } from '@aws-cdk/aws-s3'; +import { StateMachine } from '@aws-cdk/aws-stepfunctions'; +import { App, CfnOutput, RemovalPolicy, Stack } from '@aws-cdk/core'; +import { S3Location } from '../../lib'; +import { SageMakerCreateTrainingJob } from '../../lib/sagemaker/create-training-job'; + +/* + * Creates a state machine with a task state to create a training job in AWS SageMaker + * SageMaker jobs need training algorithms. These can be found in the AWS marketplace + * or created. + * + * Subscribe to demo Algorithm vended by Amazon (free): + * https://aws.amazon.com/marketplace/ai/procurement?productId=cc5186a0-b8d6-4750-a9bb-1dcdf10e787a + * FIXME - create Input data pertinent for the training model and insert into S3 location specified in inputDataConfig. + * + * Stack verification steps: + * The generated State Machine can be executed from the CLI (or Step Functions console) + * and runs with an execution status of `Succeeded`. + * + * -- aws stepfunctions start-execution --state-machine-arn provides execution arn + * -- aws stepfunctions describe-execution --execution-arn returns a status of `Succeeded` + */ +const app = new App(); +const stack = new Stack(app, 'integ-stepfunctions-sagemaker'); + +const encryptionKey = new Key(stack, 'EncryptionKey', { + removalPolicy: RemovalPolicy.DESTROY, +}); +const trainingData = new Bucket(stack, 'TrainingData', { + encryption: BucketEncryption.KMS, + encryptionKey, + removalPolicy: RemovalPolicy.DESTROY, +}); + +const sm = new StateMachine(stack, 'StateMachine', { + definition: new SageMakerCreateTrainingJob(stack, 'TrainTask', { + algorithmSpecification: { + algorithmName: 'arn:aws:sagemaker:us-east-1:865070037744:algorithm/scikit-decision-trees-15423055-57b73412d2e93e9239e4e16f83298b8f', + }, + inputDataConfig: [{ channelName: 'InputData', dataSource: { + s3DataSource: { + s3Location: S3Location.fromBucket(trainingData, 'data/'), + }, + } }], + outputDataConfig: { s3OutputLocation: S3Location.fromBucket(trainingData, 'result/') }, + trainingJobName: 'mytrainingjob', + }), +}); + +new CfnOutput(stack, 'stateMachineArn', { + value: sm.stateMachineArn, +}); diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.ts deleted file mode 100644 index 661f1f1bbd006..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.ts +++ /dev/null @@ -1,34 +0,0 @@ -import { Key } from '@aws-cdk/aws-kms'; -import { Bucket, BucketEncryption } from '@aws-cdk/aws-s3'; -import { StateMachine, Task } from '@aws-cdk/aws-stepfunctions'; -import { App, RemovalPolicy, Stack } from '@aws-cdk/core'; -import { S3Location, SagemakerTrainTask } from '../../lib'; - -const app = new App(); -const stack = new Stack(app, 'integ-stepfunctions-sagemaker'); - -const encryptionKey = new Key(stack, 'EncryptionKey', { - removalPolicy: RemovalPolicy.DESTROY, -}); -const trainingData = new Bucket(stack, 'TrainingData', { - encryption: BucketEncryption.KMS, - encryptionKey, - removalPolicy: RemovalPolicy.DESTROY, -}); - -new StateMachine(stack, 'StateMachine', { - definition: new Task(stack, 'TrainTask', { - task: new SagemakerTrainTask({ - algorithmSpecification: { - algorithmName: 'GRADIENT_ASCENT', - }, - inputDataConfig: [{ channelName: 'InputData', dataSource: { - s3DataSource: { - s3Location: S3Location.fromBucket(trainingData, 'data/'), - }, - } }], - outputDataConfig: { s3OutputLocation: S3Location.fromBucket(trainingData, 'result/') }, - trainingJobName: 'MyTrainingJob', - }), - }), -});