diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts index 7f07ec26f2259..b120dd0e4351d 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts @@ -1,6 +1,13 @@ import ec2 = require('@aws-cdk/aws-ec2'); +import ecr = require('@aws-cdk/aws-ecr'); +import { DockerImageAsset, DockerImageAssetProps } from '@aws-cdk/aws-ecr-assets'; +import iam = require('@aws-cdk/aws-iam'); import kms = require('@aws-cdk/aws-kms'); -import { Duration } from '@aws-cdk/core'; +import s3 = require('@aws-cdk/aws-s3'); +import sfn = require('@aws-cdk/aws-stepfunctions'); +import { Construct, Duration } from '@aws-cdk/core'; + +export interface ISageMakerTask extends sfn.IStepFunctionsTask, iam.IGrantable {} // // Create Training Job types @@ -24,7 +31,7 @@ export interface AlgorithmSpecification { /** * Registry path of the Docker image that contains the training algorithm. */ - readonly trainingImage?: string; + readonly trainingImage?: DockerImage; /** * Input mode that the algorithm supports. @@ -125,7 +132,7 @@ export interface S3DataSource { /** * S3 Uri */ - readonly s3Uri: string; + readonly s3Location: S3Location; } /** @@ -140,9 +147,12 @@ export interface OutputDataConfig { /** * Identifies the S3 path where you want Amazon SageMaker to store the model artifacts. */ - readonly s3OutputPath: string; + readonly s3OutputLocation: S3Location; } +/** + * @experimental + */ export interface StoppingCondition { /** * The maximum length of time, in seconds, that the training or compilation job can run. @@ -150,6 +160,9 @@ export interface StoppingCondition { readonly maxRuntime?: Duration; } +/** + * @experimental + */ export interface ResourceConfig { /** @@ -169,7 +182,7 @@ export interface ResourceConfig { /** * KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job. */ - readonly volumeKmsKeyId?: kms.IKey; + readonly volumeEncryptionKey?: kms.IKey; /** * Size of the ML storage volume that you want to provision. @@ -218,8 +231,139 @@ export interface MetricDefinition { readonly regex: string; } +/** + * @experimental + */ +export interface S3LocationConfig { + readonly uri: string; +} + +/** + * Constructs `IS3Location` objects. + * + * @experimental + */ +export abstract class S3Location { + /** + * An `IS3Location` built with a determined bucket and key prefix. + * + * @param bucket is the bucket where the objects are to be stored. + * @param keyPrefix is the key prefix used by the location. + */ + public static fromBucket(bucket: s3.IBucket, keyPrefix: string): S3Location { + return new StandardS3Location({ bucket, keyPrefix, uri: bucket.urlForObject(keyPrefix) }); + } + + /** + * An `IS3Location` determined fully by a JSON Path from the task input. + * + * Due to the dynamic nature of those locations, the IAM grants that will be set by `grantRead` and `grantWrite` + * apply to the `*` resource. + * + * @param expression the JSON expression resolving to an S3 location URI. + */ + public static fromJsonExpression(expression: string): S3Location { + return new StandardS3Location({ uri: sfn.Data.stringAt(expression) }); + } + + /** + * Called when the S3Location is bound to a StepFunctions task. + */ + public abstract bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig; +} + +/** + * Options for binding an S3 Location. + * + * @experimental + */ +export interface S3LocationBindOptions { + /** + * Allow reading from the S3 Location. + * + * @default false + */ + readonly forReading?: boolean; + + /** + * Allow writing to the S3 Location. + * + * @default false + */ + readonly forWriting?: boolean; +} + +/** + * Configuration for a using Docker image. + * + * @experimental + */ +export interface DockerImageConfig { + /** + * The fully qualified URI of the Docker image. + */ + readonly imageUri: string; +} + +/** + * Creates `IDockerImage` instances. + * + * @experimental + */ +export abstract class DockerImage { + /** + * Reference a Docker image stored in an ECR repository. + * + * @param repository the ECR repository where the image is hosted. + * @param tag an optional `tag` + */ + public static fromEcrRepository(repository: ecr.IRepository, tag: string = 'latest'): DockerImage { + return new StandardDockerImage({ repository, imageUri: repository.repositoryUriForTag(tag) }); + } + + /** + * Reference a Docker image which URI is obtained from the task's input. + * + * @param expression the JSON path expression with the task input. + * @param allowAnyEcrImagePull whether ECR access should be permitted (set to `false` if the image will never be in ECR). + */ + public static fromJsonExpression(expression: string, allowAnyEcrImagePull = true): DockerImage { + return new StandardDockerImage({ imageUri: expression, allowAnyEcrImagePull }); + } + + /** + * Reference a Docker image by it's URI. + * + * When referencing ECR images, prefer using `inEcr`. + * + * @param imageUri the URI to the docker image. + */ + public static fromRegistry(imageUri: string): DockerImage { + return new StandardDockerImage({ imageUri }); + } + + /** + * Reference a Docker image that is provided as an Asset in the current app. + * + * @param scope the scope in which to create the Asset. + * @param id the ID for the asset in the construct tree. + * @param props the configuration props of the asset. + */ + public static fromAsset(scope: Construct, id: string, props: DockerImageAssetProps): DockerImage { + const asset = new DockerImageAsset(scope, id, props); + return new StandardDockerImage({ repository: asset.repository, imageUri: asset.imageUri }); + } + + /** + * Called when the image is used by a SageMaker task. + */ + public abstract bind(task: ISageMakerTask): DockerImageConfig; +} + /** * S3 Data Type. + * + * @experimental */ export enum S3DataType { /** @@ -240,6 +384,8 @@ export enum S3DataType { /** * S3 Data Distribution Type. + * + * @experimental */ export enum S3DataDistributionType { /** @@ -255,6 +401,8 @@ export enum S3DataDistributionType { /** * Define the format of the input data. + * + * @experimental */ export enum RecordWrapperType { /** @@ -270,6 +418,8 @@ export enum RecordWrapperType { /** * Input mode that the algorithm supports. + * + * @experimental */ export enum InputMode { /** @@ -285,6 +435,8 @@ export enum InputMode { /** * Compression type of the data. + * + * @experimental */ export enum CompressionType { /** @@ -416,6 +568,8 @@ export interface TransformResources { /** * Specifies the number of records to include in a mini-batch for an HTTP inference request. + * + * @experimental */ export enum BatchStrategy { @@ -432,6 +586,8 @@ export enum BatchStrategy { /** * Method to use to split the transform job's data files into smaller batches. + * + * @experimental */ export enum SplitType { @@ -458,6 +614,8 @@ export enum SplitType { /** * How to assemble the results of the transform job as a single S3 object. + * + * @experimental */ export enum AssembleWith { @@ -472,3 +630,70 @@ export enum AssembleWith { LINE = 'Line' } + +class StandardDockerImage extends DockerImage { + private readonly allowAnyEcrImagePull: boolean; + private readonly imageUri: string; + private readonly repository?: ecr.IRepository; + + constructor(opts: { allowAnyEcrImagePull?: boolean, imageUri: string, repository?: ecr.IRepository }) { + super(); + + this.allowAnyEcrImagePull = !!opts.allowAnyEcrImagePull; + this.imageUri = opts.imageUri; + this.repository = opts.repository; + } + + public bind(task: ISageMakerTask): DockerImageConfig { + if (this.repository) { + this.repository.grantPull(task); + } + if (this.allowAnyEcrImagePull) { + task.grantPrincipal.addToPolicy(new iam.PolicyStatement({ + actions: [ + 'ecr:BatchCheckLayerAvailability', + 'ecr:GetDownloadUrlForLayer', + 'ecr:BatchGetImage', + ], + resources: ['*'] + })); + } + return { + imageUri: this.imageUri, + }; + } +} + +class StandardS3Location extends S3Location { + private readonly bucket?: s3.IBucket; + private readonly keyGlob: string; + private readonly uri: string; + + constructor(opts: { bucket?: s3.IBucket, keyPrefix?: string, uri: string }) { + super(); + this.bucket = opts.bucket; + this.keyGlob = `${opts.keyPrefix || ''}*`; + this.uri = opts.uri; + } + + public bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig { + if (this.bucket) { + if (opts.forReading) { + this.bucket.grantRead(task, this.keyGlob); + } + if (opts.forWriting) { + this.bucket.grantWrite(task, this.keyGlob); + } + } else { + const actions = new Array(); + if (opts.forReading) { + actions.push('s3:GetObject', 's3:ListBucket'); + } + if (opts.forWriting) { + actions.push('s3:PutObject'); + } + task.grantPrincipal.addToPolicy(new iam.PolicyStatement({ actions, resources: ['*'], })); + } + return { uri: this.uri }; + } +} diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts index 9cdb1af7c9e9b..61807c1adcab5 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts @@ -8,7 +8,7 @@ import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceC /** * @experimental */ -export interface SagemakerTrainProps { +export interface SagemakerTrainTaskProps { /** * Training Job Name. @@ -16,7 +16,12 @@ export interface SagemakerTrainProps { readonly trainingJobName: string; /** - * Role for thte Training Job. + * Role for the Training Job. The role must be granted all necessary permissions for the SageMaker training job to + * be able to operate. + * + * See https://docs.aws.amazon.com/fr_fr/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-createtrainingjob-perms + * + * @default - a role with appropriate permissions will be created. */ readonly role?: iam.IRole; @@ -70,8 +75,10 @@ export interface SagemakerTrainProps { /** * Class representing the SageMaker Create Training Job task. + * + * @experimental */ -export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsTask { +export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn.IStepFunctionsTask { /** * Allows specify security group connections for instances of this fleet. @@ -85,6 +92,8 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT */ public readonly role: iam.IRole; + public readonly grantPrincipal: iam.IPrincipal; + /** * The Algorithm Specification */ @@ -105,7 +114,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT */ private readonly stoppingCondition: StoppingCondition; - constructor(scope: Construct, private readonly props: SagemakerTrainProps) { + constructor(scope: Construct, private readonly props: SagemakerTrainTaskProps) { // set the default resource config if not defined. this.resourceConfig = props.resourceConfig || { @@ -120,13 +129,48 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT }; // set the sagemaker role or create new one - this.role = props.role || new iam.Role(scope, 'SagemakerRole', { + this.grantPrincipal = this.role = props.role || new iam.Role(scope, 'SagemakerRole', { assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - managedPolicies: [ - iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess') - ] + inlinePolicies: { + CreateTrainingJob: new iam.PolicyDocument({ + statements: [ + new iam.PolicyStatement({ + actions: [ + 'cloudwatch:PutMetricData', + 'logs:CreateLogStream', + 'logs:PutLogEvents', + 'logs:CreateLogGroup', + 'logs:DescribeLogStreams', + 'ecr:GetAuthorizationToken', + ...props.vpcConfig + ? [ + 'ec2:CreateNetworkInterface', + 'ec2:CreateNetworkInterfacePermission', + 'ec2:DeleteNetworkInterface', + 'ec2:DeleteNetworkInterfacePermission', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DescribeVpcs', + 'ec2:DescribeDhcpOptions', + 'ec2:DescribeSubnets', + 'ec2:DescribeSecurityGroups', + ] + : [], + ], + resources: ['*'], // Those permissions cannot be resource-scoped + }) + ] + }), + } }); + if (props.outputDataConfig.encryptionKey) { + props.outputDataConfig.encryptionKey.grantEncrypt(this.role); + } + + if (props.resourceConfig && props.resourceConfig.volumeEncryptionKey) { + props.resourceConfig.volumeEncryptionKey.grant(this.role, 'kms:CreateGrant'); + } + // set the input mode to 'File' if not defined this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ? ( props.algorithmSpecification ) : @@ -175,7 +219,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT return { AlgorithmSpecification: { TrainingInputMode: spec.trainingInputMode, - ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage } : {}, + ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage.bind(this).imageUri } : {}, ...(spec.algorithmName) ? { AlgorithmName: spec.algorithmName } : {}, ...(spec.metricDefinitions) ? { MetricDefinitions: spec.metricDefinitions @@ -190,7 +234,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT ChannelName: channel.channelName, DataSource: { S3DataSource: { - S3Uri: channel.dataSource.s3DataSource.s3Uri, + S3Uri: channel.dataSource.s3DataSource.s3Location.bind(this, { forReading: true }).uri, S3DataType: channel.dataSource.s3DataSource.s3DataType, ...(channel.dataSource.s3DataSource.s3DataDistributionType) ? { S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType} : {}, @@ -209,7 +253,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT private renderOutputDataConfig(config: OutputDataConfig): {[key: string]: any} { return { OutputDataConfig: { - S3OutputPath: config.s3OutputPath, + S3OutputPath: config.s3OutputLocation.bind(this, { forWriting: true }).uri, ...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {}, } }; @@ -221,7 +265,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT InstanceCount: config.instanceCount, InstanceType: 'ml.' + config.instanceType, VolumeSizeInGB: config.volumeSizeInGB, - ...(config.volumeKmsKeyId) ? { VolumeKmsKeyId: config.volumeKmsKeyId.keyArn } : {}, + ...(config.volumeEncryptionKey) ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {}, } }; } @@ -260,7 +304,8 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT stack.formatArn({ service: 'sagemaker', resource: 'training-job', - resourceName: '*' + // If the job name comes from input, we cannot target the policy to a particular ARN prefix reliably... + resourceName: sfn.Data.isJsonPathString(this.props.trainingJobName) ? '*' : `${this.props.trainingJobName}*` }) ], }), diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/package.json b/packages/@aws-cdk/aws-stepfunctions-tasks/package.json index 83e6796fc494f..c770746e3de00 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/package.json +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/package.json @@ -80,25 +80,33 @@ "pkglint": "^0.36.2" }, "dependencies": { + "@aws-cdk/assets": "^0.36.2", "@aws-cdk/aws-cloudwatch": "^0.36.2", "@aws-cdk/aws-ec2": "^0.36.2", + "@aws-cdk/aws-ecr": "^0.36.2", + "@aws-cdk/aws-ecr-assets": "^0.36.2", "@aws-cdk/aws-ecs": "^0.36.2", "@aws-cdk/aws-iam": "^0.36.2", "@aws-cdk/aws-kms": "^0.36.2", "@aws-cdk/aws-lambda": "^0.36.2", "@aws-cdk/aws-sns": "^0.36.2", + "@aws-cdk/aws-s3": "^0.36.2", "@aws-cdk/aws-sqs": "^0.36.2", "@aws-cdk/aws-stepfunctions": "^0.36.2", "@aws-cdk/core": "^0.36.2" }, "homepage": "https://github.com/awslabs/aws-cdk", "peerDependencies": { + "@aws-cdk/assets": "^0.36.2", "@aws-cdk/aws-cloudwatch": "^0.36.2", "@aws-cdk/aws-ec2": "^0.36.2", + "@aws-cdk/aws-ecr": "^0.36.2", + "@aws-cdk/aws-ecr-assets": "^0.36.2", "@aws-cdk/aws-ecs": "^0.36.2", "@aws-cdk/aws-iam": "^0.36.2", "@aws-cdk/aws-kms": "^0.36.2", "@aws-cdk/aws-lambda": "^0.36.2", + "@aws-cdk/aws-s3": "^0.36.2", "@aws-cdk/aws-sns": "^0.36.2", "@aws-cdk/aws-sqs": "^0.36.2", "@aws-cdk/aws-stepfunctions": "^0.36.2", diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.expected.json b/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.expected.json new file mode 100644 index 0000000000000..55c8ee820af8b --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.expected.json @@ -0,0 +1,399 @@ +{ + "Resources": { + "EncryptionKey1B843E66": { + "Type": "AWS::KMS::Key", + "Properties": { + "KeyPolicy": { + "Statement": [ + { + "Action": [ + "kms:Create*", + "kms:Describe*", + "kms:Enable*", + "kms:List*", + "kms:Put*", + "kms:Update*", + "kms:Revoke*", + "kms:Disable*", + "kms:Get*", + "kms:Delete*", + "kms:ScheduleKeyDeletion", + "kms:CancelKeyDeletion" + ], + "Effect": "Allow", + "Principal": { + "AWS": { + "Fn::Join": [ + "", + [ + "arn:", + { + "Ref": "AWS::Partition" + }, + ":iam::", + { + "Ref": "AWS::AccountId" + }, + ":root" + ] + ] + } + }, + "Resource": "*" + }, + { + "Action": [ + "kms:Decrypt", + "kms:DescribeKey" + ], + "Effect": "Allow", + "Principal": { + "AWS": { + "Fn::GetAtt": [ + "SagemakerRole5FDB64E1", + "Arn" + ] + } + }, + "Resource": "*" + }, + { + "Action": [ + "kms:Encrypt", + "kms:ReEncrypt*", + "kms:GenerateDataKey*" + ], + "Effect": "Allow", + "Principal": { + "AWS": { + "Fn::GetAtt": [ + "SagemakerRole5FDB64E1", + "Arn" + ] + } + }, + "Resource": "*" + } + ], + "Version": "2012-10-17" + } + }, + "DeletionPolicy": "Delete" + }, + "TrainingData3FDB6D34": { + "Type": "AWS::S3::Bucket", + "Properties": { + "BucketEncryption": { + "ServerSideEncryptionConfiguration": [ + { + "ServerSideEncryptionByDefault": { + "KMSMasterKeyID": { + "Fn::GetAtt": [ + "EncryptionKey1B843E66", + "Arn" + ] + }, + "SSEAlgorithm": "aws:kms" + } + } + ] + } + }, + "DeletionPolicy": "Delete" + }, + "SagemakerRole5FDB64E1": { + "Type": "AWS::IAM::Role", + "Properties": { + "AssumeRolePolicyDocument": { + "Statement": [ + { + "Action": "sts:AssumeRole", + "Effect": "Allow", + "Principal": { + "Service": { + "Fn::Join": [ + "", + [ + "sagemaker.", + { + "Ref": "AWS::URLSuffix" + } + ] + ] + } + } + } + ], + "Version": "2012-10-17" + }, + "Policies": [ + { + "PolicyDocument": { + "Statement": [ + { + "Action": [ + "cloudwatch:PutMetricData", + "logs:CreateLogStream", + "logs:PutLogEvents", + "logs:CreateLogGroup", + "logs:DescribeLogStreams", + "ecr:GetAuthorizationToken" + ], + "Effect": "Allow", + "Resource": "*" + } + ], + "Version": "2012-10-17" + }, + "PolicyName": "CreateTrainingJob" + } + ] + } + }, + "SagemakerRoleDefaultPolicy9DD21C3C": { + "Type": "AWS::IAM::Policy", + "Properties": { + "PolicyDocument": { + "Statement": [ + { + "Action": [ + "s3:GetObject*", + "s3:GetBucket*", + "s3:List*" + ], + "Effect": "Allow", + "Resource": [ + { + "Fn::GetAtt": [ + "TrainingData3FDB6D34", + "Arn" + ] + }, + { + "Fn::Join": [ + "", + [ + { + "Fn::GetAtt": [ + "TrainingData3FDB6D34", + "Arn" + ] + }, + "/data/*" + ] + ] + } + ] + }, + { + "Action": [ + "kms:Decrypt", + "kms:DescribeKey" + ], + "Effect": "Allow", + "Resource": { + "Fn::GetAtt": [ + "EncryptionKey1B843E66", + "Arn" + ] + } + }, + { + "Action": [ + "s3:DeleteObject*", + "s3:PutObject*", + "s3:Abort*" + ], + "Effect": "Allow", + "Resource": [ + { + "Fn::GetAtt": [ + "TrainingData3FDB6D34", + "Arn" + ] + }, + { + "Fn::Join": [ + "", + [ + { + "Fn::GetAtt": [ + "TrainingData3FDB6D34", + "Arn" + ] + }, + "/result/*" + ] + ] + } + ] + }, + { + "Action": [ + "kms:Encrypt", + "kms:ReEncrypt*", + "kms:GenerateDataKey*" + ], + "Effect": "Allow", + "Resource": { + "Fn::GetAtt": [ + "EncryptionKey1B843E66", + "Arn" + ] + } + } + ], + "Version": "2012-10-17" + }, + "PolicyName": "SagemakerRoleDefaultPolicy9DD21C3C", + "Roles": [ + { + "Ref": "SagemakerRole5FDB64E1" + } + ] + } + }, + "StateMachineRoleB840431D": { + "Type": "AWS::IAM::Role", + "Properties": { + "AssumeRolePolicyDocument": { + "Statement": [ + { + "Action": "sts:AssumeRole", + "Effect": "Allow", + "Principal": { + "Service": { + "Fn::Join": [ + "", + [ + "states.", + { + "Ref": "AWS::Region" + }, + ".amazonaws.com" + ] + ] + } + } + } + ], + "Version": "2012-10-17" + } + } + }, + "StateMachineRoleDefaultPolicyDF1E6607": { + "Type": "AWS::IAM::Policy", + "Properties": { + "PolicyDocument": { + "Statement": [ + { + "Action": [ + "sagemaker:CreateTrainingJob", + "sagemaker:DescribeTrainingJob", + "sagemaker:StopTrainingJob" + ], + "Effect": "Allow", + "Resource": { + "Fn::Join": [ + "", + [ + "arn:", + { + "Ref": "AWS::Partition" + }, + ":sagemaker:", + { + "Ref": "AWS::Region" + }, + ":", + { + "Ref": "AWS::AccountId" + }, + ":training-job/MyTrainingJob*" + ] + ] + } + }, + { + "Action": "sagemaker:ListTags", + "Effect": "Allow", + "Resource": "*" + }, + { + "Action": "iam:PassRole", + "Condition": { + "StringEquals": { + "iam:PassedToService": "sagemaker.amazonaws.com" + } + }, + "Effect": "Allow", + "Resource": { + "Fn::GetAtt": [ + "SagemakerRole5FDB64E1", + "Arn" + ] + } + } + ], + "Version": "2012-10-17" + }, + "PolicyName": "StateMachineRoleDefaultPolicyDF1E6607", + "Roles": [ + { + "Ref": "StateMachineRoleB840431D" + } + ] + } + }, + "StateMachine2E01A3A5": { + "Type": "AWS::StepFunctions::StateMachine", + "Properties": { + "DefinitionString": { + "Fn::Join": [ + "", + [ + "{\"StartAt\":\"TrainTask\",\"States\":{\"TrainTask\":{\"End\":true,\"Parameters\":{\"TrainingJobName\":\"MyTrainingJob\",\"RoleArn\":\"", + { + "Fn::GetAtt": [ + "SagemakerRole5FDB64E1", + "Arn" + ] + }, + "\",\"AlgorithmSpecification\":{\"TrainingInputMode\":\"File\"},\"InputDataConfig\":[{\"ChannelName\":\"InputData\",\"DataSource\":{\"S3DataSource\":{\"S3Uri\":\"https://s3.", + { + "Ref": "AWS::Region" + }, + ".", + { + "Ref": "AWS::URLSuffix" + }, + "/", + { + "Ref": "TrainingData3FDB6D34" + }, + "/data/\",\"S3DataType\":\"S3Prefix\"}}}],\"OutputDataConfig\":{\"S3OutputPath\":\"https://s3.", + { + "Ref": "AWS::Region" + }, + ".", + { + "Ref": "AWS::URLSuffix" + }, + "/", + { + "Ref": "TrainingData3FDB6D34" + }, + "/result/\"},\"ResourceConfig\":{\"InstanceCount\":1,\"InstanceType\":\"ml.m4.xlarge\",\"VolumeSizeInGB\":10},\"StoppingCondition\":{\"MaxRuntimeInSeconds\":3600}},\"Type\":\"Task\",\"Resource\":\"arn:aws:states:::sagemaker:createTrainingJob\"}}}" + ] + ] + }, + "RoleArn": { + "Fn::GetAtt": [ + "StateMachineRoleB840431D", + "Arn" + ] + } + } + } + } +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.ts new file mode 100644 index 0000000000000..8a72022d0f959 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/integ.sagemaker.ts @@ -0,0 +1,28 @@ +import { Key } from '@aws-cdk/aws-kms'; +import { Bucket, BucketEncryption } from '@aws-cdk/aws-s3'; +import { StateMachine, Task } from '@aws-cdk/aws-stepfunctions'; +import { App, RemovalPolicy, Stack } from '@aws-cdk/core'; +import { S3Location, SagemakerTrainTask } from '../lib'; + +const app = new App(); +const stack = new Stack(app, 'integ-stepfunctions-sagemaker'); + +const encryptionKey = new Key(stack, 'EncryptionKey', { + removalPolicy: RemovalPolicy.DESTROY, +}); +const trainingData = new Bucket(stack, 'TrainingData', { + encryption: BucketEncryption.KMS, + encryptionKey, + removalPolicy: RemovalPolicy.DESTROY, +}); + +new StateMachine(stack, 'StateMachine', { + definition: new Task(stack, 'TrainTask', { + task: new SagemakerTrainTask(stack, { + algorithmSpecification: {}, + inputDataConfig: [{ channelName: 'InputData', dataSource: { s3DataSource: { s3Location: S3Location.fromBucket(trainingData, 'data/') } } }], + outputDataConfig: { s3OutputLocation: S3Location.fromBucket(trainingData, 'result/') }, + trainingJobName: 'MyTrainingJob', + }) + }) +}); diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts index 58c557782c57d..43dff4d3f54a4 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts @@ -2,9 +2,11 @@ import '@aws-cdk/assert/jest'; import ec2 = require('@aws-cdk/aws-ec2'); import iam = require('@aws-cdk/aws-iam'); import kms = require('@aws-cdk/aws-kms'); +import s3 = require('@aws-cdk/aws-s3'); import sfn = require('@aws-cdk/aws-stepfunctions'); import cdk = require('@aws-cdk/core'); import tasks = require('../lib'); +import { S3Location } from '../lib'; let stack: cdk.Stack; @@ -25,13 +27,13 @@ test('create basic training job', () => { channelName: 'train', dataSource: { s3DataSource: { - s3Uri: "s3://mybucket/mytrainpath" + s3Location: S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucket', 'mybucket'), 'mytrainpath') } } } ], outputDataConfig: { - s3OutputPath: 's3://mybucket/myoutputpath' + s3OutputLocation: S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath') }, })}); @@ -51,13 +53,17 @@ test('create basic training job', () => { DataSource: { S3DataSource: { S3DataType: 'S3Prefix', - S3Uri: 's3://mybucket/mytrainpath' + S3Uri: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region'}, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/mytrainpath']] + } } } } ], OutputDataConfig: { - S3OutputPath: 's3://mybucket/myoutputpath' + S3OutputPath: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region' }, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/myoutputpath']] + } }, ResourceConfig: { InstanceCount: 1, @@ -112,7 +118,7 @@ test('create complex training job', () => { dataSource: { s3DataSource: { s3DataType: tasks.S3DataType.S3_PREFIX, - s3Uri: "s3://mybucket/mytrainpath", + s3Location: S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucketA', 'mybucket'), 'mytrainpath'), } } }, @@ -124,20 +130,20 @@ test('create complex training job', () => { dataSource: { s3DataSource: { s3DataType: tasks.S3DataType.S3_PREFIX, - s3Uri: "s3://mybucket/mytestpath", + s3Location: S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucketB', 'mybucket'), 'mytestpath'), } } } ], outputDataConfig: { - s3OutputPath: 's3://mybucket/myoutputpath', + s3OutputLocation: S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath'), encryptionKey: kmsKey }, resourceConfig: { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), volumeSizeInGB: 50, - volumeKmsKeyId: kmsKey, + volumeEncryptionKey: kmsKey, }, stoppingCondition: { maxRuntime: cdk.Duration.hours(1) @@ -179,7 +185,9 @@ test('create complex training job', () => { DataSource: { S3DataSource: { S3DataType: 'S3Prefix', - S3Uri: 's3://mybucket/mytrainpath' + S3Uri: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region'}, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/mytrainpath']] + } } } }, @@ -191,13 +199,17 @@ test('create complex training job', () => { DataSource: { S3DataSource: { S3DataType: 'S3Prefix', - S3Uri: 's3://mybucket/mytestpath' + S3Uri: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region'}, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/mytestpath']] + } } } } ], OutputDataConfig: { - S3OutputPath: 's3://mybucket/myoutputpath', + S3OutputPath: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region' }, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/myoutputpath']] + }, KmsKeyId: { "Fn::GetAtt": [ "Key961B73FD", "Arn" ] }, }, ResourceConfig: { @@ -245,13 +257,13 @@ test('pass param to training job', () => { dataSource: { s3DataSource: { s3DataType: tasks.S3DataType.S3_PREFIX, - s3Uri: sfn.Data.stringAt('$.S3Bucket') + s3Location: S3Location.fromJsonExpression('$.S3Bucket') } } } ], outputDataConfig: { - s3OutputPath: 's3://mybucket/myoutputpath' + s3OutputLocation: S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath'), }, resourceConfig: { instanceCount: 1, @@ -287,7 +299,9 @@ test('pass param to training job', () => { } ], 'OutputDataConfig': { - S3OutputPath: 's3://mybucket/myoutputpath' + S3OutputPath: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region' }, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/myoutputpath']] + } }, 'ResourceConfig': { InstanceCount: 1, diff --git a/packages/@aws-cdk/aws-stepfunctions/lib/fields.ts b/packages/@aws-cdk/aws-stepfunctions/lib/fields.ts index db2952ad43fae..2d35536cf9172 100644 --- a/packages/@aws-cdk/aws-stepfunctions/lib/fields.ts +++ b/packages/@aws-cdk/aws-stepfunctions/lib/fields.ts @@ -1,5 +1,5 @@ import { Token } from "@aws-cdk/core"; -import { findReferencedPaths, JsonPathToken, renderObject } from "./json-path"; +import { findReferencedPaths, jsonPathString, JsonPathToken, renderObject } from "./json-path"; /** * Extract a field from the State Machine data that gets passed around between states @@ -39,6 +39,10 @@ export class Data { return new JsonPathToken('$').toString(); } + public static isJsonPathString(value: string): boolean { + return !!jsonPathString(value); + } + private constructor() { } } @@ -134,4 +138,4 @@ function validateContextPath(path: string) { if (!path.startsWith('$$.')) { throw new Error("Context JSON path values must start with '$$.'"); } -} \ No newline at end of file +} diff --git a/packages/@aws-cdk/aws-stepfunctions/lib/json-path.ts b/packages/@aws-cdk/aws-stepfunctions/lib/json-path.ts index 717aa4bd06262..32aa168466038 100644 --- a/packages/@aws-cdk/aws-stepfunctions/lib/json-path.ts +++ b/packages/@aws-cdk/aws-stepfunctions/lib/json-path.ts @@ -3,7 +3,7 @@ import { captureStackTrace, IResolvable, IResolveContext, Token, Tokenization } const JSON_PATH_TOKEN_SYMBOL = Symbol.for('@aws-cdk/aws-stepfunctions.JsonPathToken'); export class JsonPathToken implements IResolvable { - public static isJsonPathToken(x: any): x is JsonPathToken { + public static isJsonPathToken(x: IResolvable): x is JsonPathToken { return (x as any)[JSON_PATH_TOKEN_SYMBOL] === true; } @@ -191,7 +191,7 @@ function renderBoolean(key: string, value: boolean): {[key: string]: boolean} { * * Otherwise return undefined. */ -function jsonPathString(x: string): string | undefined { +export function jsonPathString(x: string): string | undefined { const fragments = Tokenization.reverseString(x); const jsonPathTokens = fragments.tokens.filter(JsonPathToken.isJsonPathToken); @@ -224,4 +224,4 @@ function jsonPathNumber(x: number): string | undefined { function pathFromToken(token: IResolvable | undefined) { return token && (JsonPathToken.isJsonPathToken(token) ? token.path : undefined); -} \ No newline at end of file +} diff --git a/packages/@aws-cdk/core/lib/private/token-map.ts b/packages/@aws-cdk/core/lib/private/token-map.ts index 17e1ee33cebfa..1010fd753eef7 100644 --- a/packages/@aws-cdk/core/lib/private/token-map.ts +++ b/packages/@aws-cdk/core/lib/private/token-map.ts @@ -75,7 +75,7 @@ export class TokenMap { * Lookup a token from an encoded value */ public tokenFromEncoding(x: any): IResolvable | undefined { - if (typeof 'x' === 'string') { return this.lookupString(x); } + if (typeof x === 'string') { return this.lookupString(x); } if (Array.isArray(x)) { return this.lookupList(x); } if (Token.isUnresolved(x)) { return x; } return undefined; diff --git a/packages/@aws-cdk/core/lib/token.ts b/packages/@aws-cdk/core/lib/token.ts index ddb2856d756e1..6903b274fa1b2 100644 --- a/packages/@aws-cdk/core/lib/token.ts +++ b/packages/@aws-cdk/core/lib/token.ts @@ -161,4 +161,4 @@ export interface EncodingOptions { export function isResolvableObject(x: any): x is IResolvable { return typeof(x) === 'object' && x !== null && typeof x.resolve === 'function'; -} \ No newline at end of file +}