From 11ff939b5e2cdcb4d693467c27fb420214d01a89 Mon Sep 17 00:00:00 2001 From: Shiv Lakshminarayan Date: Tue, 2 Jun 2020 22:14:35 -0700 Subject: [PATCH 1/9] create constructs for sagemaker tasks --- .../lib/sagemaker/base-types.ts | 758 ++++++++++++++++++ .../lib/sagemaker/create-training-job.ts | 407 ++++++++++ .../lib/sagemaker/create-transform-job.ts | 271 +++++++ .../sagemaker/create-training-job.test.ts | 399 +++++++++ .../sagemaker/create-transform-job.test.ts | 243 ++++++ .../integ.create-training-job.expected.json | 402 ++++++++++ .../sagemaker/integ.create-training-job.ts | 33 + 7 files changed, 2513 insertions(+) create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts new file mode 100644 index 0000000000000..bbb9908fd58ef --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts @@ -0,0 +1,758 @@ +import * as ec2 from '@aws-cdk/aws-ec2'; +import * as ecr from '@aws-cdk/aws-ecr'; +import { DockerImageAsset, DockerImageAssetProps } from '@aws-cdk/aws-ecr-assets'; +import * as iam from '@aws-cdk/aws-iam'; +import * as kms from '@aws-cdk/aws-kms'; +import * as s3 from '@aws-cdk/aws-s3'; +import * as sfn from 
'@aws-cdk/aws-stepfunctions'; +import { Construct, Duration } from '@aws-cdk/core'; + +/** + * Task to train a machine learning model using Amazon SageMaker + * @experimental + */ +export interface ISageMakerTask extends sfn.TaskStateBase, iam.IGrantable {} + +/** + * Specify the training algorithm and algorithm-specific metadata + * @experimental + */ +export interface AlgorithmSpecification { + + /** + * Name of the algorithm resource to use for the training job. + * This must be an algorithm resource that you created or subscribe to on AWS Marketplace. + * If you specify a value for this parameter, you can't specify a value for TrainingImage. + * + * @default - No algorithm is specified + */ + readonly algorithmName?: string; + + /** + * List of metric definition objects. Each object specifies the metric name and regular expressions used to parse algorithm logs. + * + * @default - No metrics + */ + readonly metricDefinitions?: MetricDefinition[]; + + /** + * Registry path of the Docker image that contains the training algorithm. + * + * @default - No Docker image is specified + */ + readonly trainingImage?: DockerImage; + + /** + * Input mode that the algorithm supports. + * + * @default 'File' mode + */ + readonly trainingInputMode?: InputMode; +} + +/** + * Describes the training, validation or test dataset and the Amazon S3 location where it is stored. + * + * @experimental + */ +export interface Channel { + + /** + * Name of the channel + */ + readonly channelName: string; + + /** + * Compression type if training data is compressed + * + * @default - None + */ + readonly compressionType?: CompressionType; + + /** + * The MIME type of the data. + * + * @default - None + */ + readonly contentType?: string; + + /** + * Location of the channel data. + */ + readonly dataSource: DataSource; + + /** + * Input mode to use for the data channel in a training job. 
+ * + * @default - None + */ + readonly inputMode?: InputMode; + + /** + * Specify RecordIO as the value when input data is in raw format but the training algorithm requires the RecordIO format. + * In this case, Amazon SageMaker wraps each individual S3 object in a RecordIO record. + * If the input data is already in RecordIO format, you don't need to set this attribute. + * + * @default - None + */ + readonly recordWrapperType?: RecordWrapperType; + + /** + * Shuffle config option for input data in a channel. + * + * @default - None + */ + readonly shuffleConfig?: ShuffleConfig; +} + +/** + * Configuration for a shuffle option for input data in a channel. + * + * @experimental + */ +export interface ShuffleConfig { + /** + * Determines the shuffling order. + */ + readonly seed: number; +} + +/** + * Location of the channel data. + * + * @experimental + */ +export interface DataSource { + /** + * S3 location of the data source that is associated with a channel. + */ + readonly s3DataSource: S3DataSource; +} + +/** + * S3 location of the channel data. + * + * @see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html + * + * @experimental + */ +export interface S3DataSource { + /** + * List of one or more attribute names to use that are found in a specified augmented manifest file. + * + * @default - No attribute names + */ + readonly attributeNames?: string[]; + + /** + * S3 Data Distribution Type + * + * @default - None + */ + readonly s3DataDistributionType?: S3DataDistributionType; + + /** + * S3 Data Type + * + * @default S3_PREFIX + */ + readonly s3DataType?: S3DataType; + + /** + * S3 Uri + */ + readonly s3Location: S3Location; +} + +/** + * Configures the S3 bucket where SageMaker will save the result of model training + * @experimental + */ +export interface OutputDataConfig { + /** + * Optional KMS encryption key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. 
+ * + * @default - Amazon SageMaker uses the default KMS key for Amazon S3 for your role's account + */ + readonly encryptionKey?: kms.IKey; + + /** + * Identifies the S3 path where you want Amazon SageMaker to store the model artifacts. + */ + readonly s3OutputLocation: S3Location; +} + +/** + * Specifies a limit to how long a model training job can run. + * When the job reaches the time limit, Amazon SageMaker ends the training job. + * + * @experimental + */ +export interface StoppingCondition { + /** + * The maximum length of time, in seconds, that the training or compilation job can run. + * + * @default - 1 hour + */ + readonly maxRuntime?: Duration; +} + +/** + * Specifies the resources, ML compute instances, and ML storage volumes to deploy for model training. + * + * @experimental + */ +export interface ResourceConfig { + + /** + * The number of ML compute instances to use. + * + * @default 1 instance. + */ + readonly instanceCount: number; + + /** + * ML compute instance type. + * + * @default is the 'm4.xlarge' instance type. + */ + readonly instanceType: ec2.InstanceType; + + /** + * KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job. + * + * @default - Amazon SageMaker uses the default KMS key for Amazon S3 for your role's account + */ + readonly volumeEncryptionKey?: kms.IKey; + + /** + * Size of the ML storage volume that you want to provision. + * + * @default 10 GB EBS volume. + */ + readonly volumeSizeInGB: number; +} + +/** + * Specifies the VPC that you want your Amazon SageMaker training job to connect to. + * + * @experimental + */ +export interface VpcConfig { + /** + * VPC + */ + readonly vpc: ec2.IVpc; + + /** + * VPC subnets. + * + * @default - Private Subnets are selected + */ + readonly subnets?: ec2.SubnetSelection; +} + +/** + * Specifies the metric name and regular expressions used to parse algorithm logs. 
+ *
+ * @experimental
+ */
+export interface MetricDefinition {
+
+  /**
+   * Name of the metric.
+   */
+  readonly name: string;
+
+  /**
+   * Regular expression that searches the output of a training job and gets the value of the metric.
+   */
+  readonly regex: string;
+}
+
+/**
+ * Stores information about the location of an object in Amazon S3
+ *
+ * @experimental
+ */
+export interface S3LocationConfig {
+
+  /**
+   * Uniquely identifies the resource in Amazon S3
+   */
+  readonly uri: string;
+}
+
+/**
+ * Constructs `S3Location` objects.
+ *
+ * @experimental
+ */
+export abstract class S3Location {
+  /**
+   * An `S3Location` built with a determined bucket and key prefix.
+   *
+   * The location URI is the bucket's object URL for `keyPrefix`.
+   *
+   * @param bucket is the bucket where the objects are to be stored.
+   * @param keyPrefix is the key prefix used by the location.
+   */
+  public static fromBucket(bucket: s3.IBucket, keyPrefix: string): S3Location {
+    return new StandardS3Location({ bucket, keyPrefix, uri: bucket.urlForObject(keyPrefix) });
+  }
+
+  /**
+   * An `S3Location` determined fully by a JSON Path from the task input.
+   *
+   * Because no bucket is known for such a location, the IAM permissions added when the location is bound
+   * apply to the `*` resource.
+   *
+   * @param expression the JSON expression resolving to an S3 location URI.
+   */
+  public static fromJsonExpression(expression: string): S3Location {
+    return new StandardS3Location({ uri: sfn.Data.stringAt(expression) });
+  }
+
+  /**
+   * Called when the S3Location is bound to a StepFunctions task.
+   *
+   * @param task the SageMaker task that will use the location.
+   * @param opts whether the task needs to read from and/or write to the location.
+   */
+  public abstract bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig;
+}
+
+/**
+ * Options for binding an S3 Location.
+ *
+ * @experimental
+ */
+export interface S3LocationBindOptions {
+  /**
+   * Allow reading from the S3 Location.
+   *
+   * @default false
+   */
+  readonly forReading?: boolean;
+
+  /**
+   * Allow writing to the S3 Location.
+   *
+   * @default false
+   */
+  readonly forWriting?: boolean;
+}
+
+/**
+ * Configuration for using a Docker image.
+ *
+ * @experimental
+ */
+export interface DockerImageConfig {
+  /**
+   * The fully qualified URI of the Docker image.
+   */
+  readonly imageUri: string;
+}
+
+/**
+ * Creates `DockerImage` instances.
+ *
+ * @experimental
+ */
+export abstract class DockerImage {
+  /**
+   * Reference a Docker image stored in an ECR repository.
+   *
+   * @param repository the ECR repository where the image is hosted.
+   * @param tag an optional `tag` (defaults to `latest`)
+   */
+  public static fromEcrRepository(repository: ecr.IRepository, tag: string = 'latest'): DockerImage {
+    return new StandardDockerImage({ repository, imageUri: repository.repositoryUriForTag(tag) });
+  }
+
+  /**
+   * Reference a Docker image whose URI is obtained from the task's input.
+   *
+   * @param expression the JSON path expression with the task input.
+   * @param allowAnyEcrImagePull whether ECR access should be permitted (set to `false` if the image will never be in ECR).
+   * Defaults to `true`.
+   */
+  public static fromJsonExpression(expression: string, allowAnyEcrImagePull = true): DockerImage {
+    return new StandardDockerImage({ imageUri: expression, allowAnyEcrImagePull });
+  }
+
+  /**
+   * Reference a Docker image by its URI.
+   *
+   * When referencing ECR images, prefer using `fromEcrRepository`.
+   *
+   * @param imageUri the URI to the docker image.
+   */
+  public static fromRegistry(imageUri: string): DockerImage {
+    return new StandardDockerImage({ imageUri });
+  }
+
+  /**
+   * Reference a Docker image that is provided as an Asset in the current app.
+   *
+   * @param scope the scope in which to create the Asset.
+   * @param id the ID for the asset in the construct tree.
+   * @param props the configuration props of the asset.
+ */ + public static fromAsset(scope: Construct, id: string, props: DockerImageAssetProps): DockerImage { + const asset = new DockerImageAsset(scope, id, props); + return new StandardDockerImage({ repository: asset.repository, imageUri: asset.imageUri }); + } + + /** + * Called when the image is used by a SageMaker task. + */ + public abstract bind(task: ISageMakerTask): DockerImageConfig; +} + +/** + * S3 Data Type. + * + * @experimental + */ +export enum S3DataType { + /** + * Manifest File Data Type + */ + MANIFEST_FILE = 'ManifestFile', + + /** + * S3 Prefix Data Type + */ + S3_PREFIX = 'S3Prefix', + + /** + * Augmented Manifest File Data Type + */ + AUGMENTED_MANIFEST_FILE = 'AugmentedManifestFile' +} + +/** + * S3 Data Distribution Type. + * + * @experimental + */ +export enum S3DataDistributionType { + /** + * Fully replicated S3 Data Distribution Type + */ + FULLY_REPLICATED = 'FullyReplicated', + + /** + * Sharded By S3 Key Data Distribution Type + */ + SHARDED_BY_S3_KEY = 'ShardedByS3Key' +} + +/** + * Define the format of the input data. + * + * @experimental + */ +export enum RecordWrapperType { + /** + * None record wrapper type + */ + NONE = 'None', + + /** + * RecordIO record wrapper type + */ + RECORD_IO = 'RecordIO' +} + +/** + * Input mode that the algorithm supports. + * + * @experimental + */ +export enum InputMode { + /** + * Pipe mode + */ + PIPE = 'Pipe', + + /** + * File mode. + */ + FILE = 'File' +} + +/** + * Compression type of the data. + * + * @experimental + */ +export enum CompressionType { + /** + * None compression type + */ + NONE = 'None', + + /** + * Gzip compression type + */ + GZIP = 'Gzip' +} + +// +// Create Transform Job types +// + +/** + * Dataset to be transformed and the Amazon S3 location where it is stored. + * + * @experimental + */ +export interface TransformInput { + + /** + * The compression type of the transform data. 
+ * + * @default NONE + */ + readonly compressionType?: CompressionType; + + /** + * Multipurpose internet mail extension (MIME) type of the data. + * + * @default - None + */ + readonly contentType?: string; + + /** + * S3 location of the channel data + */ + readonly transformDataSource: TransformDataSource; + + /** + * Method to use to split the transform job's data files into smaller batches. + * + * @default NONE + */ + readonly splitType?: SplitType; +} + +/** + * S3 location of the input data that the model can consume. + * + * @experimental + */ +export interface TransformDataSource { + + /** + * S3 location of the input data + */ + readonly s3DataSource: TransformS3DataSource; +} + +/** + * Location of the channel data. + * + * @experimental + */ +export interface TransformS3DataSource { + + /** + * S3 Data Type + * + * @default 'S3Prefix' + */ + readonly s3DataType?: S3DataType; + + /** + * Identifies either a key name prefix or a manifest. + */ + readonly s3Uri: string; +} + +/** + * S3 location where you want Amazon SageMaker to save the results from the transform job. + * + * @experimental + */ +export interface TransformOutput { + + /** + * MIME type used to specify the output data. + * + * @default - None + */ + readonly accept?: string; + + /** + * Defines how to assemble the results of the transform job as a single S3 object. + * + * @default - None + */ + readonly assembleWith?: AssembleWith; + + /** + * AWS KMS key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. + * + * @default - default KMS key for Amazon S3 for your role's account. + */ + readonly encryptionKey?: kms.IKey; + + /** + * S3 path where you want Amazon SageMaker to store the results of the transform job. + */ + readonly s3OutputPath: string; +} + +/** + * ML compute instances for the transform job. 
+ *
+ * @experimental
+ */
+export interface TransformResources {
+
+  /**
+   * Number of ML compute instances to use in the transform job
+   */
+  readonly instanceCount: number;
+
+  /**
+   * ML compute instance type for the transform job.
+   */
+  readonly instanceType: ec2.InstanceType;
+
+  /**
+   * AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s).
+   *
+   * Typed as `kms.IKey`, like the other encryption keys in this file, so that any key
+   * implementing the interface (not only a concrete `kms.Key`) can be supplied.
+   *
+   * @default - None
+   */
+  readonly volumeKmsKeyId?: kms.IKey;
+}
+
+/**
+ * Specifies the number of records to include in a mini-batch for an HTTP inference request.
+ *
+ * @experimental
+ */
+export enum BatchStrategy {
+
+  /**
+   * Fits multiple records in a mini-batch.
+   */
+  MULTI_RECORD = 'MultiRecord',
+
+  /**
+   * Use a single record when making an invocation request.
+   */
+  SINGLE_RECORD = 'SingleRecord'
+}
+
+/**
+ * Method to use to split the transform job's data files into smaller batches.
+ *
+ * @experimental
+ */
+export enum SplitType {
+
+  /**
+   * Input data files are not split.
+   */
+  NONE = 'None',
+
+  /**
+   * Split records on a newline character boundary.
+   */
+  LINE = 'Line',
+
+  /**
+   * Split using MXNet RecordIO format.
+   */
+  RECORD_IO = 'RecordIO',
+
+  /**
+   * Split using TensorFlow TFRecord format.
+   */
+  TF_RECORD = 'TFRecord'
+}
+
+/**
+ * How to assemble the results of the transform job as a single S3 object.
+ *
+ * @experimental
+ */
+export enum AssembleWith {
+
+  /**
+   * Concatenate the results in binary format.
+   */
+  NONE = 'None',
+
+  /**
+   * Add a newline character at the end of every transformed record.
+   */
+  LINE = 'Line'
+
+}
+
+/**
+ * `DockerImage` implementation returned by the `DockerImage` factory methods.
+ *
+ * Binding grants the task pull access on the repository (when one is known) and,
+ * when `allowAnyEcrImagePull` is set, ECR pull permissions on all resources.
+ */
+class StandardDockerImage extends DockerImage {
+  private readonly allowAnyEcrImagePull: boolean;
+  private readonly imageUri: string;
+  private readonly repository?: ecr.IRepository;
+
+  constructor(opts: { allowAnyEcrImagePull?: boolean, imageUri: string, repository?: ecr.IRepository }) {
+    super();
+
+    // an omitted flag is treated as `false`
+    this.allowAnyEcrImagePull = !!opts.allowAnyEcrImagePull;
+    this.imageUri = opts.imageUri;
+    this.repository = opts.repository;
+  }
+
+  /**
+   * Grants the task the permissions needed to pull the image and returns the image URI.
+   *
+   * @param task the SageMaker task that uses the image.
+   */
+  public bind(task: ISageMakerTask): DockerImageConfig {
+    if (this.repository) {
+      this.repository.grantPull(task);
+    }
+    if (this.allowAnyEcrImagePull) {
+      task.grantPrincipal.addToPolicy(new iam.PolicyStatement({
+        actions: [
+          'ecr:BatchCheckLayerAvailability',
+          'ecr:GetDownloadUrlForLayer',
+          'ecr:BatchGetImage',
+        ],
+        resources: ['*'],
+      }));
+    }
+    return {
+      imageUri: this.imageUri,
+    };
+  }
+}
+
+/**
+ * `S3Location` implementation returned by the `S3Location` factory methods.
+ *
+ * With a bucket, grants are scoped to the bucket and the `<keyPrefix>*` key pattern.
+ * Without a bucket, the permissions are added on the `*` resource.
+ */
+class StandardS3Location extends S3Location {
+  private readonly bucket?: s3.IBucket;
+  private readonly keyGlob: string;
+  private readonly uri: string;
+
+  constructor(opts: { bucket?: s3.IBucket, keyPrefix?: string, uri: string }) {
+    super();
+    this.bucket = opts.bucket;
+    // a missing prefix yields the pattern `*`
+    this.keyGlob = `${opts.keyPrefix || ''}*`;
+    this.uri = opts.uri;
+  }
+
+  /**
+   * Grants the task the requested access to the location and returns its URI.
+   *
+   * @param task the SageMaker task that uses the location.
+   * @param opts whether read and/or write access is required.
+   */
+  public bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig {
+    if (this.bucket) {
+      if (opts.forReading) {
+        this.bucket.grantRead(task, this.keyGlob);
+      }
+      if (opts.forWriting) {
+        this.bucket.grantWrite(task, this.keyGlob);
+      }
+    } else {
+      const actions: string[] = [];
+      if (opts.forReading) {
+        actions.push('s3:GetObject', 's3:ListBucket');
+      }
+      if (opts.forWriting) {
+        actions.push('s3:PutObject');
+      }
+      // A statement without any action is not a usable IAM statement, so only add one when access was requested.
+      if (actions.length > 0) {
+        task.grantPrincipal.addToPolicy(new iam.PolicyStatement({ actions, resources: ['*'] }));
+      }
+    }
+    return { uri: this.uri };
+  }
+}
diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts
new file mode 100644
index 
0000000000000..5a66d346c2c5e --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -0,0 +1,407 @@ +import * as ec2 from '@aws-cdk/aws-ec2'; +import * as iam from '@aws-cdk/aws-iam'; +import * as sfn from '@aws-cdk/aws-stepfunctions'; +import { Construct, Duration, Lazy, Stack } from '@aws-cdk/core'; +import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; +import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, + S3DataType, StoppingCondition, VpcConfig } from './sagemaker-task-base-types'; + +/** + * Properties for creating an Amazon SageMaker training job + * + * @experimental + */ +export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps { + + /** + * Training Job Name. + */ + readonly trainingJobName: string; + + /** + * Role for the Training Job. The role must be granted all necessary permissions for the SageMaker training job to + * be able to operate. + * + * See https://docs.aws.amazon.com/fr_fr/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-createtrainingjob-perms + * + * @default - a role with appropriate permissions will be created. + */ + readonly role?: iam.IRole; + + /** + * Identifies the training algorithm to use. + */ + readonly algorithmSpecification: AlgorithmSpecification; + + /** + * Algorithm-specific parameters that influence the quality of the model. Set hyperparameters before you start the learning process. + * For a list of hyperparameters provided by Amazon SageMaker + * @see https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html + * + * @default - No hyperparameters + */ + readonly hyperparameters?: {[key: string]: any}; + + /** + * Describes the various datasets (e.g. train, validation, test) and the Amazon S3 location where stored. + */ + readonly inputDataConfig: Channel[]; + + /** + * Tags to be applied to the train job. 
+ * + * @default - No tags + */ + readonly tags?: {[key: string]: string}; + + /** + * Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training. + */ + readonly outputDataConfig: OutputDataConfig; + + /** + * Specifies the resources, ML compute instances, and ML storage volumes to deploy for model training. + * + * @default - 1 instance of EC2 `M4.XLarge` with `10GB` volume + */ + readonly resourceConfig?: ResourceConfig; + + /** + * Sets a time limit for training. + * + * @default - max runtime of 1 hour + */ + readonly stoppingCondition?: StoppingCondition; + + /** + * Specifies the VPC that you want your training job to connect to. + * + * @default - No VPC + */ + readonly vpcConfig?: VpcConfig; +} + +/** + * Class representing the SageMaker Create Training Job task. + * + * @experimental + */ +export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam.IGrantable, ec2.IConnectable { + + private static readonly SUPPORTED_INTEGRATION_PATTERNS: sfn.IntegrationPattern[] = [ + sfn.IntegrationPattern.REQUEST_RESPONSE, + sfn.IntegrationPattern.RUN_JOB, + ]; + + /** + * Allows specify security group connections for instances of this fleet. + */ + public readonly connections: ec2.Connections = new ec2.Connections(); + + protected readonly taskPolicies?: iam.PolicyStatement[]; + protected readonly taskMetrics?: sfn.TaskMetricsConfig; + + /** + * The Algorithm Specification + */ + private readonly algorithmSpecification: AlgorithmSpecification; + + /** + * The Input Data Config. + */ + private readonly inputDataConfig: Channel[]; + + /** + * The resource config for the task. + */ + private readonly resourceConfig: ResourceConfig; + + /** + * The resource config for the task. 
+ */ + private readonly stoppingCondition: StoppingCondition; + + private readonly vpc?: ec2.IVpc; + private securityGroup?: ec2.ISecurityGroup; + private readonly securityGroups: ec2.ISecurityGroup[] = []; + private readonly subnets?: string[]; + private readonly integrationPattern: sfn.IntegrationPattern; + private _role?: iam.IRole; + private _grantPrincipal?: iam.IPrincipal; + + constructor(scope: Construct, id: string, private readonly props: SageMakerCreateTrainingJobProps) { + super(scope, id, props); + + this.integrationPattern = props.integrationPattern || sfn.IntegrationPattern.REQUEST_RESPONSE; + validatePatternSupported(this.integrationPattern, SageMakerCreateTrainingJob.SUPPORTED_INTEGRATION_PATTERNS); + + // set the default resource config if not defined. + this.resourceConfig = props.resourceConfig || { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE), + volumeSizeInGB: 10, + }; + + // set the stopping condition if not defined + this.stoppingCondition = props.stoppingCondition || { + maxRuntime: Duration.hours(1), + }; + + // check that either algorithm name or image is defined + if ((!props.algorithmSpecification.algorithmName) && (!props.algorithmSpecification.trainingImage)) { + throw new Error('Must define either an algorithm name or training image URI in the algorithm specification'); + } + + // set the input mode to 'File' if not defined + this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ? 
+ ( props.algorithmSpecification ) : + ( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } ); + + // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined + this.inputDataConfig = props.inputDataConfig.map(config => { + if (!config.dataSource.s3DataSource.s3DataType) { + return Object.assign({}, config, { dataSource: { s3DataSource: + { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } }); + } else { + return config; + } + }); + + // add the security groups to the connections object + if (props.vpcConfig) { + this.vpc = props.vpcConfig.vpc; + this.subnets = (props.vpcConfig.subnets) ? + (this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds) : this.vpc.selectSubnets().subnetIds; + } + + this.taskPolicies = this.makePolicyStatements(); + } + + /** + * The execution role for the Sagemaker training job. + * + * Only available after task has been added to a state machine. + */ + public get role(): iam.IRole { + if (this._role === undefined) { + throw new Error('role not available yet--use the object in a Task first'); + } + return this._role; + } + + public get grantPrincipal(): iam.IPrincipal { + if (this._grantPrincipal === undefined) { + throw new Error('Principal not available yet--use the object in a Task first'); + } + return this._grantPrincipal; + } + + /** + * Add the security group to all instances via the launch configuration + * security groups array. 
+   *
+   * @param securityGroup: The security group to add
+   */
+  public addSecurityGroup(securityGroup: ec2.ISecurityGroup): void {
+    this.securityGroups.push(securityGroup);
+  }
+
+  /**
+   * Renders the task definition; delegates to `bind()`.
+   */
+  public renderTask(): any {
+    return this.bind();
+  }
+
+  /**
+   * Builds the `Resource` ARN and the `Parameters` of the `createTrainingJob` service integration.
+   */
+  public bind(): any {
+    return {
+      Resource: integrationResourceArn('sagemaker', 'createTrainingJob', this.integrationPattern),
+      Parameters: sfn.FieldUtils.renderObject(this.renderParameters()),
+    };
+  }
+
+  /**
+   * Assembles the request parameters from the individual `render*` helpers.
+   */
+  private renderParameters(): {[key: string]: any} {
+    return {
+      TrainingJobName: this.props.trainingJobName,
+      RoleArn: this._role!.roleArn,
+      ...(this.renderAlgorithmSpecification(this.algorithmSpecification)),
+      ...(this.renderInputDataConfig(this.inputDataConfig)),
+      ...(this.renderOutputDataConfig(this.props.outputDataConfig)),
+      ...(this.renderResourceConfig(this.resourceConfig)),
+      ...(this.renderStoppingCondition(this.stoppingCondition)),
+      ...(this.renderHyperparameters(this.props.hyperparameters)),
+      ...(this.renderTags(this.props.tags)),
+      ...(this.renderVpcConfig(this.props.vpcConfig)),
+    };
+  }
+
+  /**
+   * Renders `AlgorithmSpecification`. Binding the training image lets the image grant pull access to this task.
+   */
+  private renderAlgorithmSpecification(spec: AlgorithmSpecification): {[key: string]: any} {
+    return {
+      AlgorithmSpecification: {
+        TrainingInputMode: spec.trainingInputMode,
+        ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage.bind(this).imageUri } : {},
+        ...(spec.algorithmName) ? { AlgorithmName: spec.algorithmName } : {},
+        ...(spec.metricDefinitions) ?
+          { MetricDefinitions: spec.metricDefinitions
+            .map(metric => ({ Name: metric.name, Regex: metric.regex })) } : {},
+      },
+    };
+  }
+
+  /**
+   * Renders `InputDataConfig`, one entry per channel.
+   *
+   * Each channel's S3 location is bound with `forReading` so the location can grant read access to this task.
+   * Optional channel settings are only emitted when they are set.
+   */
+  private renderInputDataConfig(config: Channel[]): {[key: string]: any} {
+    return {
+      InputDataConfig: config.map(channel => ({
+        ChannelName: channel.channelName,
+        DataSource: {
+          S3DataSource: {
+            S3Uri: channel.dataSource.s3DataSource.s3Location.bind(this, { forReading: true }).uri,
+            S3DataType: channel.dataSource.s3DataSource.s3DataType,
+            ...(channel.dataSource.s3DataSource.s3DataDistributionType) ?
+              { S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType} : {},
+            ...(channel.dataSource.s3DataSource.attributeNames) ?
+              { AttributeNames: channel.dataSource.s3DataSource.attributeNames } : {},
+          },
+        },
+        ...(channel.compressionType) ? { CompressionType: channel.compressionType } : {},
+        ...(channel.contentType) ? { ContentType: channel.contentType } : {},
+        ...(channel.inputMode) ? { InputMode: channel.inputMode } : {},
+        ...(channel.recordWrapperType) ? { RecordWrapperType: channel.recordWrapperType } : {},
+        ...(channel.shuffleConfig) ? { ShuffleConfig: { Seed: channel.shuffleConfig.seed } } : {},
+      })),
+    };
+  }
+
+  /**
+   * Renders `OutputDataConfig`. The output location is bound with `forWriting`.
+   */
+  private renderOutputDataConfig(config: OutputDataConfig): {[key: string]: any} {
+    return {
+      OutputDataConfig: {
+        S3OutputPath: config.s3OutputLocation.bind(this, { forWriting: true }).uri,
+        ...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {},
+      },
+    };
+  }
+
+  /**
+   * Renders `ResourceConfig`.
+   */
+  private renderResourceConfig(config: ResourceConfig): {[key: string]: any} {
+    return {
+      ResourceConfig: {
+        InstanceCount: config.instanceCount,
+        // NOTE(review): relies on the string conversion of `ec2.InstanceType` (presumably e.g. `m4.xlarge`) - confirm.
+        InstanceType: 'ml.' + config.instanceType,
+        VolumeSizeInGB: config.volumeSizeInGB,
+        ...(config.volumeEncryptionKey) ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {},
+      },
+    };
+  }
+
+  /**
+   * Renders `StoppingCondition`, converting the maximum runtime to seconds when it is set.
+   */
+  private renderStoppingCondition(config: StoppingCondition): {[key: string]: any} {
+    return {
+      StoppingCondition: {
+        MaxRuntimeInSeconds: config.maxRuntime && config.maxRuntime.toSeconds(),
+      },
+    };
+  }
+
+  /**
+   * Renders `HyperParameters` when hyperparameters were supplied.
+   */
+  private renderHyperparameters(params: {[key: string]: any} | undefined): {[key: string]: any} {
+    return (params) ? { HyperParameters: params } : {};
+  }
+
+  /**
+   * Renders `Tags` as a list of `{ Key, Value }` pairs when tags were supplied.
+   */
+  private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} {
+    return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {};
+  }
+
+  /**
+   * Renders `VpcConfig` when a VPC configuration was supplied; otherwise contributes nothing.
+   */
+  private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} {
+    return (config) ? 
{ VpcConfig: { + SecurityGroupIds: Lazy.listValue({ produce: () => (this.securityGroups.map(sg => (sg.securityGroupId))) }), + Subnets: this.subnets, + }} : {}; + } + + private makePolicyStatements(): iam.PolicyStatement[] { + // set the sagemaker role or create new one + this._grantPrincipal = this._role = this.props.role || new iam.Role(this, 'SagemakerRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + inlinePolicies: { + CreateTrainingJob: new iam.PolicyDocument({ + statements: [ + new iam.PolicyStatement({ + actions: [ + 'cloudwatch:PutMetricData', + 'logs:CreateLogStream', + 'logs:PutLogEvents', + 'logs:CreateLogGroup', + 'logs:DescribeLogStreams', + 'ecr:GetAuthorizationToken', + ...this.props.vpcConfig + ? [ + 'ec2:CreateNetworkInterface', + 'ec2:CreateNetworkInterfacePermission', + 'ec2:DeleteNetworkInterface', + 'ec2:DeleteNetworkInterfacePermission', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DescribeVpcs', + 'ec2:DescribeDhcpOptions', + 'ec2:DescribeSubnets', + 'ec2:DescribeSecurityGroups', + ] + : [], + ], + resources: ['*'], // Those permissions cannot be resource-scoped + }), + ], + }), + }, + }); + + if (this.props.outputDataConfig.encryptionKey) { + this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role); + } + + if (this.props.resourceConfig && this.props.resourceConfig.volumeEncryptionKey) { + this.props.resourceConfig.volumeEncryptionKey.grant(this._role, 'kms:CreateGrant'); + } + + // create a security group if not defined + if (this.vpc && this.securityGroup === undefined) { + this.securityGroup = new ec2.SecurityGroup(this, 'TrainJobSecurityGroup', { + vpc: this.vpc, + }); + this.connections.addSecurityGroup(this.securityGroup); + this.securityGroups.push(this.securityGroup); + } + + const stack = Stack.of(this); + + // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html + const policyStatements = [ + new iam.PolicyStatement({ + actions: ['sagemaker:CreateTrainingJob', 
'sagemaker:DescribeTrainingJob', 'sagemaker:StopTrainingJob'], + resources: [ + stack.formatArn({ + service: 'sagemaker', + resource: 'training-job', + // If the job name comes from input, we cannot target the policy to a particular ARN prefix reliably... + resourceName: sfn.Data.isJsonPathString(this.props.trainingJobName) ? '*' : `${this.props.trainingJobName}*`, + }), + ], + }), + new iam.PolicyStatement({ + actions: ['sagemaker:ListTags'], + resources: ['*'], + }), + new iam.PolicyStatement({ + actions: ['iam:PassRole'], + resources: [this._role!.roleArn], + conditions: { + StringEquals: { 'iam:PassedToService': 'sagemaker.amazonaws.com' }, + }, + }), + ]; + + if (this.integrationPattern === sfn.IntegrationPattern.RUN_JOB) { + policyStatements.push(new iam.PolicyStatement({ + actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], + resources: [stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule', + })], + })); + } + + return policyStatements; + } +} diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts new file mode 100644 index 0000000000000..eb80f742dcb42 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts @@ -0,0 +1,271 @@ +import * as ec2 from '@aws-cdk/aws-ec2'; +import * as iam from '@aws-cdk/aws-iam'; +import * as sfn from '@aws-cdk/aws-stepfunctions'; +import { Construct, Stack } from '@aws-cdk/core'; +import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; +import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; + +/** + * Properties for creating an Amazon SageMaker transform job task + * + * @experimental + */ +export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps {
+ + /** + * Transform Job Name. + */ + readonly transformJobName: string; + + /** + * Role for the Transform Job. + * + * @default - A role is created with `AmazonSageMakerFullAccess` managed policy + */ + readonly role?: iam.IRole; + + /** + * Number of records to include in a mini-batch for an HTTP inference request. + * + * @default - No batch strategy + */ + readonly batchStrategy?: BatchStrategy; + + /** + * Environment variables to set in the Docker container. + * + * @default - No environment variables + */ + readonly environment?: {[key: string]: string}; + + /** + * Maximum number of parallel requests that can be sent to each instance in a transform job. + * + * @default - Amazon SageMaker checks the optional execution-parameters to determine the settings for your chosen algorithm. + * If the execution-parameters endpoint is not enabled, the default value is 1. + */ + readonly maxConcurrentTransforms?: number; + + /** + * Maximum allowed size of the payload, in MB. + * + * @default 6 + */ + readonly maxPayloadInMB?: number; + + /** + * Name of the model that you want to use for the transform job. + */ + readonly modelName: string; + + /** + * Tags to be applied to the transform job. + * + * @default - No tags + */ + readonly tags?: {[key: string]: string}; + + /** + * Dataset to be transformed and the Amazon S3 location where it is stored. + */ + readonly transformInput: TransformInput; + + /** + * S3 location where you want Amazon SageMaker to save the results from the transform job. + */ + readonly transformOutput: TransformOutput; + + /** + * ML compute instances for the transform job. + * + * @default - 1 instance of type M4.XLarge + */ + readonly transformResources?: TransformResources; +} + +/** + * Class representing the SageMaker Create Transform Job task.
+ * + * @experimental + */ +export class SageMakerCreateTransformJob extends sfn.TaskStateBase { + + private static readonly SUPPORTED_INTEGRATION_PATTERNS: sfn.IntegrationPattern[] = [ + sfn.IntegrationPattern.REQUEST_RESPONSE, + sfn.IntegrationPattern.RUN_JOB, + ]; + + protected readonly taskPolicies?: iam.PolicyStatement[]; + protected readonly taskMetrics?: sfn.TaskMetricsConfig; + + /** + * Dataset to be transformed and the Amazon S3 location where it is stored. + */ + private readonly transformInput: TransformInput; + + /** + * ML compute instances for the transform job. + */ + private readonly transformResources: TransformResources; + private readonly integrationPattern: sfn.IntegrationPattern; + private _role?: iam.IRole; + + constructor(scope: Construct, id: string, private readonly props: SageMakerCreateTransformJobProps) { + super(scope, id, props); + this.integrationPattern = props.integrationPattern || sfn.IntegrationPattern.REQUEST_RESPONSE; + validatePatternSupported(this.integrationPattern, SageMakerCreateTransformJob.SUPPORTED_INTEGRATION_PATTERNS); + + // set the sagemaker role or create new one + if (props.role) { + this._role = props.role; + } + + // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined + this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? 
(props.transformInput) : + Object.assign({}, props.transformInput, + { transformDataSource: + { s3DataSource: + { ...props.transformInput.transformDataSource.s3DataSource, + s3DataType: S3DataType.S3_PREFIX, + }, + }, + }); + + // set the default value for the transform resources + this.transformResources = props.transformResources || { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE), + }; + + this.taskPolicies = this.makePolicyStatements(); + } + + public renderTask(): any { + return { + Resource: integrationResourceArn('sagemaker', 'createTransformJob', this.integrationPattern), + Parameters: sfn.FieldUtils.renderObject(this.renderParameters()), + }; + } + + /** + * The execution role for the SageMaker transform job. + * + * Either the role supplied in the props or the one created by this task's constructor. + */ + public get role(): iam.IRole { + if (this._role === undefined) { + throw new Error('role not available yet--use the object in a Task first'); + } + return this._role; + } + + private renderParameters(): {[key: string]: any} { + return { + ...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {}, + ...(this.renderEnvironment(this.props.environment)), + ...(this.props.maxConcurrentTransforms) ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}, + ...(this.props.maxPayloadInMB !== undefined) ? { MaxPayloadInMB: this.props.maxPayloadInMB } : {}, + ModelName: this.props.modelName, + ...(this.renderTags(this.props.tags)), + ...(this.renderTransformInput(this.transformInput)), + TransformJobName: this.props.transformJobName, + ...(this.renderTransformOutput(this.props.transformOutput)), + ...(this.renderTransformResources(this.transformResources)), + }; + } + + private renderTransformInput(input: TransformInput): {[key: string]: any} { + return { + TransformInput: { + ...(input.compressionType) ? { CompressionType: input.compressionType } : {}, + ...(input.contentType) ?
{ ContentType: input.contentType } : {}, + DataSource: { + S3DataSource: { + S3Uri: input.transformDataSource.s3DataSource.s3Uri, + S3DataType: input.transformDataSource.s3DataSource.s3DataType, + }, + }, + ...(input.splitType) ? { SplitType: input.splitType } : {}, + }, + }; + } + + private renderTransformOutput(output: TransformOutput): {[key: string]: any} { + return { + TransformOutput: { + S3OutputPath: output.s3OutputPath, + ...(output.encryptionKey) ? { KmsKeyId: output.encryptionKey.keyArn } : {}, + ...(output.accept) ? { Accept: output.accept } : {}, + ...(output.assembleWith) ? { AssembleWith: output.assembleWith } : {}, + }, + }; + } + + private renderTransformResources(resources: TransformResources): {[key: string]: any} { + return { + TransformResources: { + InstanceCount: resources.instanceCount, + InstanceType: 'ml.' + resources.instanceType, + ...(resources.volumeKmsKeyId) ? { VolumeKmsKeyId: resources.volumeKmsKeyId.keyArn } : {}, + }, + }; + } + + private renderEnvironment(environment: {[key: string]: any} | undefined): {[key: string]: any} { + return (environment) ? { Environment: environment } : {}; + } + + private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { + return (tags) ? 
{ Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; + } + + private makePolicyStatements(): iam.PolicyStatement[] { + const stack = Stack.of(this); + + // create new role if doesn't exist + if (this._role === undefined) { + this._role = new iam.Role(this, 'SagemakerTransformRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicies: [ + iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), + ], + }); + } + + // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html + const policyStatements = [ + new iam.PolicyStatement({ + actions: ['sagemaker:CreateTransformJob', 'sagemaker:DescribeTransformJob', 'sagemaker:StopTransformJob'], + resources: [stack.formatArn({ + service: 'sagemaker', + resource: 'transform-job', + resourceName: '*', + })], + }), + new iam.PolicyStatement({ + actions: ['sagemaker:ListTags'], + resources: ['*'], + }), + new iam.PolicyStatement({ + actions: ['iam:PassRole'], + resources: [this.role.roleArn], + conditions: { + StringEquals: { 'iam:PassedToService': 'sagemaker.amazonaws.com' }, + }, + }), + ]; + + if (this.integrationPattern === sfn.IntegrationPattern.RUN_JOB) { + policyStatements.push(new iam.PolicyStatement({ + actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], + resources: [stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule', + }) ], + })); + } + + return policyStatements; + } +} diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts new file mode 100644 index 0000000000000..a186a4e4917fb --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts @@ -0,0 +1,399 @@ +import '@aws-cdk/assert/jest'; +import * as ec2 from '@aws-cdk/aws-ec2'; +import * as iam from 
'@aws-cdk/aws-iam'; +import * as kms from '@aws-cdk/aws-kms'; +import * as s3 from '@aws-cdk/aws-s3'; +import * as sfn from '@aws-cdk/aws-stepfunctions'; +import * as cdk from '@aws-cdk/core'; +import * as tasks from '../../lib'; +import { SageMakerCreateTrainingJob } from '../../lib/sagemaker/create-training-job'; + +let stack: cdk.Stack; + +beforeEach(() => { + // GIVEN + stack = new cdk.Stack(); +}); + +test('create basic training job', () => { + // WHEN + const task = new SageMakerCreateTrainingJob(stack, 'TrainSagemaker', { + trainingJobName: 'MyTrainJob', + algorithmSpecification: { + algorithmName: 'BlazingText', + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3Location: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucket', 'mybucket'), 'mytrainpath'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath'), + }, + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: { + 'Fn::Join': [ + '', + [ + 'arn:', + { + Ref: 'AWS::Partition', + }, + ':states:::sagemaker:createTrainingJob', + ], + ], + }, + End: true, + Parameters: { + AlgorithmSpecification: { + AlgorithmName: 'BlazingText', + TrainingInputMode: 'File', + }, + InputDataConfig: [ + { + ChannelName: 'train', + DataSource: { + S3DataSource: { + S3DataType: 'S3Prefix', + S3Uri: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region'}, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/mytrainpath']], + }, + }, + }, + }, + ], + OutputDataConfig: { + S3OutputPath: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region' }, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/myoutputpath']], + }, + }, + ResourceConfig: { + InstanceCount: 1, + InstanceType: 'ml.m4.xlarge', + VolumeSizeInGB: 10, + }, + RoleArn: { 'Fn::GetAtt': [ 'TrainSagemakerSagemakerRole89E8C593', 'Arn' ] }, + StoppingCondition: { + 
MaxRuntimeInSeconds: 3600, + }, + TrainingJobName: 'MyTrainJob', + }, + }); +}); + +test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration pattern', () => { + expect(() => { + new SageMakerCreateTrainingJob(stack, 'TrainSagemaker', { + integrationPattern: sfn.IntegrationPattern.WAIT_FOR_TASK_TOKEN, + trainingJobName: 'MyTrainJob', + algorithmSpecification: { + algorithmName: 'BlazingText', + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3Location: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucket', 'mybucket'), 'mytrainpath'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath'), + }, + }); + }).toThrow(/Unsupported service integration pattern. Supported Patterns: REQUEST_RESPONSE,RUN_JOB. Received: WAIT_FOR_TASK_TOKEN/i); +}); + +test('create complex training job', () => { + // WHEN + const kmsKey = new kms.Key(stack, 'Key'); + const vpc = new ec2.Vpc(stack, 'VPC'); + const securityGroup = new ec2.SecurityGroup(stack, 'SecurityGroup', { vpc, description: 'My SG' }); + securityGroup.addIngressRule(ec2.Peer.anyIpv4(), ec2.Port.tcp(22), 'allow ssh access from the world'); + + const role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicies: [ + iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), + ], + }); + + const trainTask = new SageMakerCreateTrainingJob(stack, 'TrainSagemaker', { + trainingJobName: 'MyTrainJob', + integrationPattern: sfn.IntegrationPattern.RUN_JOB, + role, + algorithmSpecification: { + algorithmName: 'BlazingText', + trainingInputMode: tasks.InputMode.FILE, + metricDefinitions: [ + { + name: 'mymetric', regex: 'regex_pattern', + }, + ], + }, + hyperparameters: { + lr: '0.1', + }, + inputDataConfig: [ + { + channelName: 'train', + contentType: 'image/jpeg', + 
compressionType: tasks.CompressionType.NONE, + recordWrapperType: tasks.RecordWrapperType.RECORD_IO, + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucketA', 'mybucket'), 'mytrainpath'), + }, + }, + }, + { + channelName: 'test', + contentType: 'image/jpeg', + compressionType: tasks.CompressionType.GZIP, + recordWrapperType: tasks.RecordWrapperType.RECORD_IO, + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucketB', 'mybucket'), 'mytestpath'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath'), + encryptionKey: kmsKey, + }, + resourceConfig: { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), + volumeSizeInGB: 50, + volumeEncryptionKey: kmsKey, + }, + stoppingCondition: { + maxRuntime: cdk.Duration.hours(1), + }, + tags: { + Project: 'MyProject', + }, + vpcConfig: { + vpc, + }, + }); + trainTask.addSecurityGroup(securityGroup); + + // THEN + expect(stack.resolve(trainTask.toStateJson())).toEqual({ + Type: 'Task', + Resource: { + 'Fn::Join': [ + '', + [ + 'arn:', + { + Ref: 'AWS::Partition', + }, + ':states:::sagemaker:createTrainingJob.sync', + ], + ], + }, + End: true, + Parameters: { + TrainingJobName: 'MyTrainJob', + RoleArn: { 'Fn::GetAtt': [ 'Role1ABCC5F0', 'Arn' ] }, + AlgorithmSpecification: { + TrainingInputMode: 'File', + AlgorithmName: 'BlazingText', + MetricDefinitions: [ + { Name: 'mymetric', Regex: 'regex_pattern' }, + ], + }, + HyperParameters: { + lr: '0.1', + }, + InputDataConfig: [ + { + ChannelName: 'train', + CompressionType: 'None', + RecordWrapperType: 'RecordIO', + ContentType: 'image/jpeg', + DataSource: { + S3DataSource: { + S3DataType: 'S3Prefix', + S3Uri: { + 
'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region'}, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/mytrainpath']], + }, + }, + }, + }, + { + ChannelName: 'test', + CompressionType: 'Gzip', + RecordWrapperType: 'RecordIO', + ContentType: 'image/jpeg', + DataSource: { + S3DataSource: { + S3DataType: 'S3Prefix', + S3Uri: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region'}, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/mytestpath']], + }, + }, + }, + }, + ], + OutputDataConfig: { + S3OutputPath: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region' }, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/myoutputpath']], + }, + KmsKeyId: { 'Fn::GetAtt': [ 'Key961B73FD', 'Arn' ] }, + }, + ResourceConfig: { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeSizeInGB: 50, + VolumeKmsKeyId: { 'Fn::GetAtt': [ 'Key961B73FD', 'Arn' ] }, + }, + StoppingCondition: { + MaxRuntimeInSeconds: 3600, + }, + Tags: [ + { Key: 'Project', Value: 'MyProject' }, + ], + VpcConfig: { + SecurityGroupIds: [ + { 'Fn::GetAtt': [ 'TrainSagemakerTrainJobSecurityGroup7C858EB9', 'GroupId' ] }, + { 'Fn::GetAtt': [ 'SecurityGroupDD263621', 'GroupId' ] }, + ], + Subnets: [ + { Ref: 'VPCPrivateSubnet1Subnet8BCA10E0' }, + { Ref: 'VPCPrivateSubnet2SubnetCFCDAA7A' }, + ], + }, + }, + }); +}); + +test('pass param to training job', () => { + // WHEN + const role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicies: [ + iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), + ], + }); + + const task = new SageMakerCreateTrainingJob(stack, 'TrainSagemaker', { + trainingJobName: sfn.Data.stringAt('$.JobName'), + role, + algorithmSpecification: { + algorithmName: 'BlazingText', + trainingInputMode: tasks.InputMode.FILE, + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, 
+ ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath'), + }, + resourceConfig: { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), + volumeSizeInGB: 50, + }, + stoppingCondition: { + maxRuntime: cdk.Duration.hours(1), + }, + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: { + 'Fn::Join': [ + '', + [ + 'arn:', + { + Ref: 'AWS::Partition', + }, + ':states:::sagemaker:createTrainingJob', + ], + ], + }, + End: true, + Parameters: { + 'TrainingJobName.$': '$.JobName', + 'RoleArn': { 'Fn::GetAtt': [ 'Role1ABCC5F0', 'Arn' ] }, + 'AlgorithmSpecification': { + TrainingInputMode: 'File', + AlgorithmName: 'BlazingText', + }, + 'InputDataConfig': [ + { + ChannelName: 'train', + DataSource: { + S3DataSource: { + 'S3DataType': 'S3Prefix', + 'S3Uri.$': '$.S3Bucket', + }, + }, + }, + ], + 'OutputDataConfig': { + S3OutputPath: { + 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region' }, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/myoutputpath']], + }, + }, + 'ResourceConfig': { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeSizeInGB: 50, + }, + 'StoppingCondition': { + MaxRuntimeInSeconds: 3600, + }, + }, + }); +}); + +test('Cannot create a SageMaker train task with both algorithm name and image name missing', () => { + + expect(() => new SageMakerCreateTrainingJob(stack, 'SageMakerTrainingTask', { + trainingJobName: 'myTrainJob', + algorithmSpecification: {}, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), + }, + }, + }, + ], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), + }, + })) + .toThrowError(/Must define either an algorithm name or 
training image URI in the algorithm specification/); +}); diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts new file mode 100644 index 0000000000000..21d700d6e06dc --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts @@ -0,0 +1,243 @@ +import '@aws-cdk/assert/jest'; +import * as ec2 from '@aws-cdk/aws-ec2'; +import * as iam from '@aws-cdk/aws-iam'; +import * as kms from '@aws-cdk/aws-kms'; +import * as sfn from '@aws-cdk/aws-stepfunctions'; +import * as cdk from '@aws-cdk/core'; +import * as tasks from '../../lib'; +import { SageMakerCreateTransformJob } from '../../lib/sagemaker/create-transform-job'; + +let stack: cdk.Stack; +let role: iam.Role; + +beforeEach(() => { + // GIVEN + stack = new cdk.Stack(); + role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicies: [ + iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), + ], + }); +}); + +test('create basic transform job', () => { + // WHEN + const task = new SageMakerCreateTransformJob(stack, 'TransformTask', { + transformJobName: 'MyTransformJob', + modelName: 'MyModelName', + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + }, + }, + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + }, + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: { + 'Fn::Join': [ + '', + [ + 'arn:', + { + Ref: 'AWS::Partition', + }, + ':states:::sagemaker:createTransformJob', + ], + ], + }, + End: true, + Parameters: { + TransformJobName: 'MyTransformJob', + ModelName: 'MyModelName', + TransformInput: { + DataSource: { + S3DataSource: { + S3Uri: 's3://inputbucket/prefix', + S3DataType: 'S3Prefix', + }, + }, + }, + TransformOutput: { + S3OutputPath: 
's3://outputbucket/prefix', + }, + TransformResources: { + InstanceCount: 1, + InstanceType: 'ml.m4.xlarge', + }, + }, + }); +}); + +test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration pattern', () => { + expect(() => { + new SageMakerCreateTransformJob(stack, 'TransformTask', { + integrationPattern: sfn.IntegrationPattern.WAIT_FOR_TASK_TOKEN, + transformJobName: 'MyTransformJob', + modelName: 'MyModelName', + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + }, + }, + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + }, + }); + }).toThrow(/Unsupported service integration pattern. Supported Patterns: REQUEST_RESPONSE,RUN_JOB. Received: WAIT_FOR_TASK_TOKEN/); +}); + +test('create complex transform job', () => { + // WHEN + const kmsKey = new kms.Key(stack, 'Key'); + const task = new SageMakerCreateTransformJob(stack, 'TransformTask', { + transformJobName: 'MyTransformJob', + modelName: 'MyModelName', + integrationPattern: sfn.IntegrationPattern.RUN_JOB, + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + s3DataType: tasks.S3DataType.S3_PREFIX, + }, + }, + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + encryptionKey: kmsKey, + }, + transformResources: { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), + volumeKmsKeyId: kmsKey, + }, + tags: { + Project: 'MyProject', + }, + batchStrategy: tasks.BatchStrategy.MULTI_RECORD, + environment: { + SOMEVAR: 'myvalue', + }, + maxConcurrentTransforms: 3, + maxPayloadInMB: 100, + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: { + 'Fn::Join': [ + '', + [ + 'arn:', + { + Ref: 'AWS::Partition', + }, + ':states:::sagemaker:createTransformJob.sync', + ], + ], + }, + End: true, + Parameters: { + TransformJobName: 'MyTransformJob', + ModelName: 'MyModelName', + 
TransformInput: { + DataSource: { + S3DataSource: { + S3Uri: 's3://inputbucket/prefix', + S3DataType: 'S3Prefix', + }, + }, + }, + TransformOutput: { + S3OutputPath: 's3://outputbucket/prefix', + KmsKeyId: { 'Fn::GetAtt': [ 'Key961B73FD', 'Arn' ] }, + }, + TransformResources: { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeKmsKeyId: { 'Fn::GetAtt': [ 'Key961B73FD', 'Arn' ] }, + }, + Tags: [ + { Key: 'Project', Value: 'MyProject' }, + ], + MaxConcurrentTransforms: 3, + MaxPayloadInMB: 100, + Environment: { + SOMEVAR: 'myvalue', + }, + BatchStrategy: 'MultiRecord', + }, + }); +}); + +test('pass param to transform job', () => { + // WHEN + const task = new SageMakerCreateTransformJob(stack, 'TransformTask', { + transformJobName: sfn.Data.stringAt('$.TransformJobName'), + modelName: sfn.Data.stringAt('$.ModelName'), + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + s3DataType: tasks.S3DataType.S3_PREFIX, + }, + }, + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + }, + transformResources: { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), + }, + }); + + // THEN + expect(stack.resolve(task.toStateJson())).toEqual({ + Type: 'Task', + Resource: { + 'Fn::Join': [ + '', + [ + 'arn:', + { + Ref: 'AWS::Partition', + }, + ':states:::sagemaker:createTransformJob', + ], + ], + }, + End: true, + Parameters: { + 'TransformJobName.$': '$.TransformJobName', + 'ModelName.$': '$.ModelName', + 'TransformInput': { + DataSource: { + S3DataSource: { + S3Uri: 's3://inputbucket/prefix', + S3DataType: 'S3Prefix', + }, + }, + }, + 'TransformOutput': { + S3OutputPath: 's3://outputbucket/prefix', + }, + 'TransformResources': { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + }, + }, + }); +}); diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json 
b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json new file mode 100644 index 0000000000000..b8af9b15f61bf --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json @@ -0,0 +1,402 @@ +{ + "Resources": { + "EncryptionKey1B843E66": { + "Type": "AWS::KMS::Key", + "Properties": { + "KeyPolicy": { + "Statement": [ + { + "Action": [ + "kms:Create*", + "kms:Describe*", + "kms:Enable*", + "kms:List*", + "kms:Put*", + "kms:Update*", + "kms:Revoke*", + "kms:Disable*", + "kms:Get*", + "kms:Delete*", + "kms:ScheduleKeyDeletion", + "kms:CancelKeyDeletion", + "kms:GenerateDataKey", + "kms:TagResource", + "kms:UntagResource" + ], + "Effect": "Allow", + "Principal": { + "AWS": { + "Fn::Join": [ + "", + [ + "arn:", + { + "Ref": "AWS::Partition" + }, + ":iam::", + { + "Ref": "AWS::AccountId" + }, + ":root" + ] + ] + } + }, + "Resource": "*" + }, + { + "Action": [ + "kms:Decrypt", + "kms:DescribeKey" + ], + "Effect": "Allow", + "Principal": { + "AWS": { + "Fn::GetAtt": [ + "TrainTaskSagemakerRole0A9B1CDD", + "Arn" + ] + } + }, + "Resource": "*" + }, + { + "Action": [ + "kms:Encrypt", + "kms:ReEncrypt*", + "kms:GenerateDataKey*" + ], + "Effect": "Allow", + "Principal": { + "AWS": { + "Fn::GetAtt": [ + "TrainTaskSagemakerRole0A9B1CDD", + "Arn" + ] + } + }, + "Resource": "*" + } + ], + "Version": "2012-10-17" + } + }, + "UpdateReplacePolicy": "Delete", + "DeletionPolicy": "Delete" + }, + "TrainingData3FDB6D34": { + "Type": "AWS::S3::Bucket", + "Properties": { + "BucketEncryption": { + "ServerSideEncryptionConfiguration": [ + { + "ServerSideEncryptionByDefault": { + "KMSMasterKeyID": { + "Fn::GetAtt": [ + "EncryptionKey1B843E66", + "Arn" + ] + }, + "SSEAlgorithm": "aws:kms" + } + } + ] + } + }, + "UpdateReplacePolicy": "Delete", + "DeletionPolicy": "Delete" + }, + "TrainTaskSagemakerRole0A9B1CDD": { + "Type": "AWS::IAM::Role", + "Properties": { + "AssumeRolePolicyDocument": { + 
"Statement": [ + { + "Action": "sts:AssumeRole", + "Effect": "Allow", + "Principal": { + "Service": "sagemaker.amazonaws.com" + } + } + ], + "Version": "2012-10-17" + }, + "Policies": [ + { + "PolicyDocument": { + "Statement": [ + { + "Action": [ + "cloudwatch:PutMetricData", + "logs:CreateLogStream", + "logs:PutLogEvents", + "logs:CreateLogGroup", + "logs:DescribeLogStreams", + "ecr:GetAuthorizationToken" + ], + "Effect": "Allow", + "Resource": "*" + } + ], + "Version": "2012-10-17" + }, + "PolicyName": "CreateTrainingJob" + } + ] + } + }, + "TrainTaskSagemakerRoleDefaultPolicyA28F72FA": { + "Type": "AWS::IAM::Policy", + "Properties": { + "PolicyDocument": { + "Statement": [ + { + "Action": [ + "s3:GetObject*", + "s3:GetBucket*", + "s3:List*" + ], + "Effect": "Allow", + "Resource": [ + { + "Fn::GetAtt": [ + "TrainingData3FDB6D34", + "Arn" + ] + }, + { + "Fn::Join": [ + "", + [ + { + "Fn::GetAtt": [ + "TrainingData3FDB6D34", + "Arn" + ] + }, + "/data/*" + ] + ] + } + ] + }, + { + "Action": [ + "kms:Decrypt", + "kms:DescribeKey" + ], + "Effect": "Allow", + "Resource": { + "Fn::GetAtt": [ + "EncryptionKey1B843E66", + "Arn" + ] + } + }, + { + "Action": [ + "s3:DeleteObject*", + "s3:PutObject*", + "s3:Abort*" + ], + "Effect": "Allow", + "Resource": [ + { + "Fn::GetAtt": [ + "TrainingData3FDB6D34", + "Arn" + ] + }, + { + "Fn::Join": [ + "", + [ + { + "Fn::GetAtt": [ + "TrainingData3FDB6D34", + "Arn" + ] + }, + "/result/*" + ] + ] + } + ] + }, + { + "Action": [ + "kms:Encrypt", + "kms:ReEncrypt*", + "kms:GenerateDataKey*" + ], + "Effect": "Allow", + "Resource": { + "Fn::GetAtt": [ + "EncryptionKey1B843E66", + "Arn" + ] + } + } + ], + "Version": "2012-10-17" + }, + "PolicyName": "TrainTaskSagemakerRoleDefaultPolicyA28F72FA", + "Roles": [ + { + "Ref": "TrainTaskSagemakerRole0A9B1CDD" + } + ] + } + }, + "StateMachineRoleB840431D": { + "Type": "AWS::IAM::Role", + "Properties": { + "AssumeRolePolicyDocument": { + "Statement": [ + { + "Action": "sts:AssumeRole", + "Effect": 
"Allow", + "Principal": { + "Service": { + "Fn::Join": [ + "", + [ + "states.", + { + "Ref": "AWS::Region" + }, + ".amazonaws.com" + ] + ] + } + } + } + ], + "Version": "2012-10-17" + } + } + }, + "StateMachineRoleDefaultPolicyDF1E6607": { + "Type": "AWS::IAM::Policy", + "Properties": { + "PolicyDocument": { + "Statement": [ + { + "Action": [ + "sagemaker:CreateTrainingJob", + "sagemaker:DescribeTrainingJob", + "sagemaker:StopTrainingJob" + ], + "Effect": "Allow", + "Resource": { + "Fn::Join": [ + "", + [ + "arn:", + { + "Ref": "AWS::Partition" + }, + ":sagemaker:", + { + "Ref": "AWS::Region" + }, + ":", + { + "Ref": "AWS::AccountId" + }, + ":training-job/MyTrainingJob*" + ] + ] + } + }, + { + "Action": "sagemaker:ListTags", + "Effect": "Allow", + "Resource": "*" + }, + { + "Action": "iam:PassRole", + "Condition": { + "StringEquals": { + "iam:PassedToService": "sagemaker.amazonaws.com" + } + }, + "Effect": "Allow", + "Resource": { + "Fn::GetAtt": [ + "TrainTaskSagemakerRole0A9B1CDD", + "Arn" + ] + } + } + ], + "Version": "2012-10-17" + }, + "PolicyName": "StateMachineRoleDefaultPolicyDF1E6607", + "Roles": [ + { + "Ref": "StateMachineRoleB840431D" + } + ] + } + }, + "StateMachine2E01A3A5": { + "Type": "AWS::StepFunctions::StateMachine", + "Properties": { + "DefinitionString": { + "Fn::Join": [ + "", + [ + "{\"StartAt\":\"TrainTask\",\"States\":{\"TrainTask\":{\"End\":true,\"Type\":\"Task\",\"Resource\":\"arn:", + { + "Ref": "AWS::Partition" + }, + ":states:::sagemaker:createTrainingJob\",\"Parameters\":{\"TrainingJobName\":\"MyTrainingJob\",\"RoleArn\":\"", + { + "Fn::GetAtt": [ + "TrainTaskSagemakerRole0A9B1CDD", + "Arn" + ] + }, + "\",\"AlgorithmSpecification\":{\"TrainingInputMode\":\"File\",\"AlgorithmName\":\"GRADIENT_ASCENT\"},\"InputDataConfig\":[{\"ChannelName\":\"InputData\",\"DataSource\":{\"S3DataSource\":{\"S3Uri\":\"https://s3.", + { + "Ref": "AWS::Region" + }, + ".", + { + "Ref": "AWS::URLSuffix" + }, + "/", + { + "Ref": "TrainingData3FDB6D34" + }, + 
"/data/\",\"S3DataType\":\"S3Prefix\"}}}],\"OutputDataConfig\":{\"S3OutputPath\":\"https://s3.", + { + "Ref": "AWS::Region" + }, + ".", + { + "Ref": "AWS::URLSuffix" + }, + "/", + { + "Ref": "TrainingData3FDB6D34" + }, + "/result/\"},\"ResourceConfig\":{\"InstanceCount\":1,\"InstanceType\":\"ml.m4.xlarge\",\"VolumeSizeInGB\":10},\"StoppingCondition\":{\"MaxRuntimeInSeconds\":3600}}}}}" + ] + ] + }, + "RoleArn": { + "Fn::GetAtt": [ + "StateMachineRoleB840431D", + "Arn" + ] + } + }, + "DependsOn": [ + "StateMachineRoleDefaultPolicyDF1E6607", + "StateMachineRoleB840431D" + ] + } + } +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts new file mode 100644 index 0000000000000..74d89a5a5674c --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts @@ -0,0 +1,33 @@ +import { Key } from '@aws-cdk/aws-kms'; +import { Bucket, BucketEncryption } from '@aws-cdk/aws-s3'; +import { StateMachine } from '@aws-cdk/aws-stepfunctions'; +import { App, RemovalPolicy, Stack } from '@aws-cdk/core'; +import { S3Location } from '../../lib'; +import { SageMakerCreateTrainingJob } from '../../lib/sagemaker/create-training-job'; + +const app = new App(); +const stack = new Stack(app, 'integ-stepfunctions-sagemaker'); + +const encryptionKey = new Key(stack, 'EncryptionKey', { + removalPolicy: RemovalPolicy.DESTROY, +}); +const trainingData = new Bucket(stack, 'TrainingData', { + encryption: BucketEncryption.KMS, + encryptionKey, + removalPolicy: RemovalPolicy.DESTROY, +}); + +new StateMachine(stack, 'StateMachine', { + definition: new SageMakerCreateTrainingJob(stack, 'TrainTask', { + algorithmSpecification: { + algorithmName: 'GRADIENT_ASCENT', + }, + inputDataConfig: [{ channelName: 'InputData', dataSource: { + s3DataSource: { + s3Location: S3Location.fromBucket(trainingData, 
'data/'), + }, + } }], + outputDataConfig: { s3OutputLocation: S3Location.fromBucket(trainingData, 'result/') }, + trainingJobName: 'MyTrainingJob', + }), +}); From 5e9c6fe83a1107f9851118829cc8d54274458c48 Mon Sep 17 00:00:00 2001 From: Shiv Lakshminarayan Date: Tue, 2 Jun 2020 23:26:22 -0700 Subject: [PATCH 2/9] use a single base-types for now, can rename later --- .../lib/sagemaker/base-types.ts | 758 ------------------ .../lib/sagemaker/create-training-job.ts | 210 ++--- .../sagemaker/sagemaker-task-base-types.ts | 2 +- 3 files changed, 111 insertions(+), 859 deletions(-) delete mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts deleted file mode 100644 index bbb9908fd58ef..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts +++ /dev/null @@ -1,758 +0,0 @@ -import * as ec2 from '@aws-cdk/aws-ec2'; -import * as ecr from '@aws-cdk/aws-ecr'; -import { DockerImageAsset, DockerImageAssetProps } from '@aws-cdk/aws-ecr-assets'; -import * as iam from '@aws-cdk/aws-iam'; -import * as kms from '@aws-cdk/aws-kms'; -import * as s3 from '@aws-cdk/aws-s3'; -import * as sfn from '@aws-cdk/aws-stepfunctions'; -import { Construct, Duration } from '@aws-cdk/core'; - -/** - * Task to train a machine learning model using Amazon SageMaker - * @experimental - */ -export interface ISageMakerTask extends sfn.TaskStateBase, iam.IGrantable {} - -/** - * Specify the training algorithm and algorithm-specific metadata - * @experimental - */ -export interface AlgorithmSpecification { - - /** - * Name of the algorithm resource to use for the training job. - * This must be an algorithm resource that you created or subscribe to on AWS Marketplace. - * If you specify a value for this parameter, you can't specify a value for TrainingImage. 
- * - * @default - No algorithm is specified - */ - readonly algorithmName?: string; - - /** - * List of metric definition objects. Each object specifies the metric name and regular expressions used to parse algorithm logs. - * - * @default - No metrics - */ - readonly metricDefinitions?: MetricDefinition[]; - - /** - * Registry path of the Docker image that contains the training algorithm. - * - * @default - No Docker image is specified - */ - readonly trainingImage?: DockerImage; - - /** - * Input mode that the algorithm supports. - * - * @default 'File' mode - */ - readonly trainingInputMode?: InputMode; -} - -/** - * Describes the training, validation or test dataset and the Amazon S3 location where it is stored. - * - * @experimental - */ -export interface Channel { - - /** - * Name of the channel - */ - readonly channelName: string; - - /** - * Compression type if training data is compressed - * - * @default - None - */ - readonly compressionType?: CompressionType; - - /** - * The MIME type of the data. - * - * @default - None - */ - readonly contentType?: string; - - /** - * Location of the channel data. - */ - readonly dataSource: DataSource; - - /** - * Input mode to use for the data channel in a training job. - * - * @default - None - */ - readonly inputMode?: InputMode; - - /** - * Specify RecordIO as the value when input data is in raw format but the training algorithm requires the RecordIO format. - * In this case, Amazon SageMaker wraps each individual S3 object in a RecordIO record. - * If the input data is already in RecordIO format, you don't need to set this attribute. - * - * @default - None - */ - readonly recordWrapperType?: RecordWrapperType; - - /** - * Shuffle config option for input data in a channel. - * - * @default - None - */ - readonly shuffleConfig?: ShuffleConfig; -} - -/** - * Configuration for a shuffle option for input data in a channel. 
- * - * @experimental - */ -export interface ShuffleConfig { - /** - * Determines the shuffling order. - */ - readonly seed: number; -} - -/** - * Location of the channel data. - * - * @experimental - */ -export interface DataSource { - /** - * S3 location of the data source that is associated with a channel. - */ - readonly s3DataSource: S3DataSource; -} - -/** - * S3 location of the channel data. - * - * @see https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_S3DataSource.html - * - * @experimental - */ -export interface S3DataSource { - /** - * List of one or more attribute names to use that are found in a specified augmented manifest file. - * - * @default - No attribute names - */ - readonly attributeNames?: string[]; - - /** - * S3 Data Distribution Type - * - * @default - None - */ - readonly s3DataDistributionType?: S3DataDistributionType; - - /** - * S3 Data Type - * - * @default S3_PREFIX - */ - readonly s3DataType?: S3DataType; - - /** - * S3 Uri - */ - readonly s3Location: S3Location; -} - -/** - * Configures the S3 bucket where SageMaker will save the result of model training - * @experimental - */ -export interface OutputDataConfig { - /** - * Optional KMS encryption key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. - * - * @default - Amazon SageMaker uses the default KMS key for Amazon S3 for your role's account - */ - readonly encryptionKey?: kms.IKey; - - /** - * Identifies the S3 path where you want Amazon SageMaker to store the model artifacts. - */ - readonly s3OutputLocation: S3Location; -} - -/** - * Specifies a limit to how long a model training job can run. - * When the job reaches the time limit, Amazon SageMaker ends the training job. - * - * @experimental - */ -export interface StoppingCondition { - /** - * The maximum length of time, in seconds, that the training or compilation job can run. 
- * - * @default - 1 hour - */ - readonly maxRuntime?: Duration; -} - -/** - * Specifies the resources, ML compute instances, and ML storage volumes to deploy for model training. - * - * @experimental - */ -export interface ResourceConfig { - - /** - * The number of ML compute instances to use. - * - * @default 1 instance. - */ - readonly instanceCount: number; - - /** - * ML compute instance type. - * - * @default is the 'm4.xlarge' instance type. - */ - readonly instanceType: ec2.InstanceType; - - /** - * KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job. - * - * @default - Amazon SageMaker uses the default KMS key for Amazon S3 for your role's account - */ - readonly volumeEncryptionKey?: kms.IKey; - - /** - * Size of the ML storage volume that you want to provision. - * - * @default 10 GB EBS volume. - */ - readonly volumeSizeInGB: number; -} - -/** - * Specifies the VPC that you want your Amazon SageMaker training job to connect to. - * - * @experimental - */ -export interface VpcConfig { - /** - * VPC - */ - readonly vpc: ec2.IVpc; - - /** - * VPC subnets. - * - * @default - Private Subnets are selected - */ - readonly subnets?: ec2.SubnetSelection; -} - -/** - * Specifies the metric name and regular expressions used to parse algorithm logs. - * - * @experimental - */ -export interface MetricDefinition { - - /** - * Name of the metric. - */ - readonly name: string; - - /** - * Regular expression that searches the output of a training job and gets the value of the metric. - */ - readonly regex: string; -} - -/** - * Stores information about the location of an object in Amazon S3 - * - * @experimental - */ -export interface S3LocationConfig { - - /** - * Uniquely identifies the resource in Amazon S3 - */ - readonly uri: string; -} - -/** - * Constructs `IS3Location` objects. 
- * - * @experimental - */ -export abstract class S3Location { - /** - * An `IS3Location` built with a determined bucket and key prefix. - * - * @param bucket is the bucket where the objects are to be stored. - * @param keyPrefix is the key prefix used by the location. - */ - public static fromBucket(bucket: s3.IBucket, keyPrefix: string): S3Location { - return new StandardS3Location({ bucket, keyPrefix, uri: bucket.urlForObject(keyPrefix) }); - } - - /** - * An `IS3Location` determined fully by a JSON Path from the task input. - * - * Due to the dynamic nature of those locations, the IAM grants that will be set by `grantRead` and `grantWrite` - * apply to the `*` resource. - * - * @param expression the JSON expression resolving to an S3 location URI. - */ - public static fromJsonExpression(expression: string): S3Location { - return new StandardS3Location({ uri: sfn.Data.stringAt(expression) }); - } - - /** - * Called when the S3Location is bound to a StepFunctions task. - */ - public abstract bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig; -} - -/** - * Options for binding an S3 Location. - * - * @experimental - */ -export interface S3LocationBindOptions { - /** - * Allow reading from the S3 Location. - * - * @default false - */ - readonly forReading?: boolean; - - /** - * Allow writing to the S3 Location. - * - * @default false - */ - readonly forWriting?: boolean; -} - -/** - * Configuration for a using Docker image. - * - * @experimental - */ -export interface DockerImageConfig { - /** - * The fully qualified URI of the Docker image. - */ - readonly imageUri: string; -} - -/** - * Creates `IDockerImage` instances. - * - * @experimental - */ -export abstract class DockerImage { - /** - * Reference a Docker image stored in an ECR repository. - * - * @param repository the ECR repository where the image is hosted. 
- * @param tag an optional `tag` - */ - public static fromEcrRepository(repository: ecr.IRepository, tag: string = 'latest'): DockerImage { - return new StandardDockerImage({ repository, imageUri: repository.repositoryUriForTag(tag) }); - } - - /** - * Reference a Docker image which URI is obtained from the task's input. - * - * @param expression the JSON path expression with the task input. - * @param allowAnyEcrImagePull whether ECR access should be permitted (set to `false` if the image will never be in ECR). - */ - public static fromJsonExpression(expression: string, allowAnyEcrImagePull = true): DockerImage { - return new StandardDockerImage({ imageUri: expression, allowAnyEcrImagePull }); - } - - /** - * Reference a Docker image by it's URI. - * - * When referencing ECR images, prefer using `inEcr`. - * - * @param imageUri the URI to the docker image. - */ - public static fromRegistry(imageUri: string): DockerImage { - return new StandardDockerImage({ imageUri }); - } - - /** - * Reference a Docker image that is provided as an Asset in the current app. - * - * @param scope the scope in which to create the Asset. - * @param id the ID for the asset in the construct tree. - * @param props the configuration props of the asset. - */ - public static fromAsset(scope: Construct, id: string, props: DockerImageAssetProps): DockerImage { - const asset = new DockerImageAsset(scope, id, props); - return new StandardDockerImage({ repository: asset.repository, imageUri: asset.imageUri }); - } - - /** - * Called when the image is used by a SageMaker task. - */ - public abstract bind(task: ISageMakerTask): DockerImageConfig; -} - -/** - * S3 Data Type. 
- * - * @experimental - */ -export enum S3DataType { - /** - * Manifest File Data Type - */ - MANIFEST_FILE = 'ManifestFile', - - /** - * S3 Prefix Data Type - */ - S3_PREFIX = 'S3Prefix', - - /** - * Augmented Manifest File Data Type - */ - AUGMENTED_MANIFEST_FILE = 'AugmentedManifestFile' -} - -/** - * S3 Data Distribution Type. - * - * @experimental - */ -export enum S3DataDistributionType { - /** - * Fully replicated S3 Data Distribution Type - */ - FULLY_REPLICATED = 'FullyReplicated', - - /** - * Sharded By S3 Key Data Distribution Type - */ - SHARDED_BY_S3_KEY = 'ShardedByS3Key' -} - -/** - * Define the format of the input data. - * - * @experimental - */ -export enum RecordWrapperType { - /** - * None record wrapper type - */ - NONE = 'None', - - /** - * RecordIO record wrapper type - */ - RECORD_IO = 'RecordIO' -} - -/** - * Input mode that the algorithm supports. - * - * @experimental - */ -export enum InputMode { - /** - * Pipe mode - */ - PIPE = 'Pipe', - - /** - * File mode. - */ - FILE = 'File' -} - -/** - * Compression type of the data. - * - * @experimental - */ -export enum CompressionType { - /** - * None compression type - */ - NONE = 'None', - - /** - * Gzip compression type - */ - GZIP = 'Gzip' -} - -// -// Create Transform Job types -// - -/** - * Dataset to be transformed and the Amazon S3 location where it is stored. - * - * @experimental - */ -export interface TransformInput { - - /** - * The compression type of the transform data. - * - * @default NONE - */ - readonly compressionType?: CompressionType; - - /** - * Multipurpose internet mail extension (MIME) type of the data. - * - * @default - None - */ - readonly contentType?: string; - - /** - * S3 location of the channel data - */ - readonly transformDataSource: TransformDataSource; - - /** - * Method to use to split the transform job's data files into smaller batches. 
- * - * @default NONE - */ - readonly splitType?: SplitType; -} - -/** - * S3 location of the input data that the model can consume. - * - * @experimental - */ -export interface TransformDataSource { - - /** - * S3 location of the input data - */ - readonly s3DataSource: TransformS3DataSource; -} - -/** - * Location of the channel data. - * - * @experimental - */ -export interface TransformS3DataSource { - - /** - * S3 Data Type - * - * @default 'S3Prefix' - */ - readonly s3DataType?: S3DataType; - - /** - * Identifies either a key name prefix or a manifest. - */ - readonly s3Uri: string; -} - -/** - * S3 location where you want Amazon SageMaker to save the results from the transform job. - * - * @experimental - */ -export interface TransformOutput { - - /** - * MIME type used to specify the output data. - * - * @default - None - */ - readonly accept?: string; - - /** - * Defines how to assemble the results of the transform job as a single S3 object. - * - * @default - None - */ - readonly assembleWith?: AssembleWith; - - /** - * AWS KMS key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. - * - * @default - default KMS key for Amazon S3 for your role's account. - */ - readonly encryptionKey?: kms.IKey; - - /** - * S3 path where you want Amazon SageMaker to store the results of the transform job. - */ - readonly s3OutputPath: string; -} - -/** - * ML compute instances for the transform job. - * - * @experimental - */ -export interface TransformResources { - - /** - * Number of ML compute instances to use in the transform job - */ - readonly instanceCount: number; - - /** - * ML compute instance type for the transform job. - */ - readonly instanceType: ec2.InstanceType; - - /** - * AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s). 
- * - * @default - None - */ - readonly volumeKmsKeyId?: kms.Key; -} - -/** - * Specifies the number of records to include in a mini-batch for an HTTP inference request. - * - * @experimental - */ -export enum BatchStrategy { - - /** - * Fits multiple records in a mini-batch. - */ - MULTI_RECORD = 'MultiRecord', - - /** - * Use a single record when making an invocation request. - */ - SINGLE_RECORD = 'SingleRecord' -} - -/** - * Method to use to split the transform job's data files into smaller batches. - * - * @experimental - */ -export enum SplitType { - - /** - * Input data files are not split, - */ - NONE = 'None', - - /** - * Split records on a newline character boundary. - */ - LINE = 'Line', - - /** - * Split using MXNet RecordIO format. - */ - RECORD_IO = 'RecordIO', - - /** - * Split using TensorFlow TFRecord format. - */ - TF_RECORD = 'TFRecord' -} - -/** - * How to assemble the results of the transform job as a single S3 object. - * - * @experimental - */ -export enum AssembleWith { - - /** - * Concatenate the results in binary format. - */ - NONE = 'None', - - /** - * Add a newline character at the end of every transformed record. 
- */ - LINE = 'Line' - -} - -class StandardDockerImage extends DockerImage { - private readonly allowAnyEcrImagePull: boolean; - private readonly imageUri: string; - private readonly repository?: ecr.IRepository; - - constructor(opts: { allowAnyEcrImagePull?: boolean, imageUri: string, repository?: ecr.IRepository }) { - super(); - - this.allowAnyEcrImagePull = !!opts.allowAnyEcrImagePull; - this.imageUri = opts.imageUri; - this.repository = opts.repository; - } - - public bind(task: ISageMakerTask): DockerImageConfig { - if (this.repository) { - this.repository.grantPull(task); - } - if (this.allowAnyEcrImagePull) { - task.grantPrincipal.addToPolicy(new iam.PolicyStatement({ - actions: [ - 'ecr:BatchCheckLayerAvailability', - 'ecr:GetDownloadUrlForLayer', - 'ecr:BatchGetImage', - ], - resources: ['*'], - })); - } - return { - imageUri: this.imageUri, - }; - } -} - -class StandardS3Location extends S3Location { - private readonly bucket?: s3.IBucket; - private readonly keyGlob: string; - private readonly uri: string; - - constructor(opts: { bucket?: s3.IBucket, keyPrefix?: string, uri: string }) { - super(); - this.bucket = opts.bucket; - this.keyGlob = `${opts.keyPrefix || ''}*`; - this.uri = opts.uri; - } - - public bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig { - if (this.bucket) { - if (opts.forReading) { - this.bucket.grantRead(task, this.keyGlob); - } - if (opts.forWriting) { - this.bucket.grantWrite(task, this.keyGlob); - } - } else { - const actions = new Array(); - if (opts.forReading) { - actions.push('s3:GetObject', 's3:ListBucket'); - } - if (opts.forWriting) { - actions.push('s3:PutObject'); - } - task.grantPrincipal.addToPolicy(new iam.PolicyStatement({ actions, resources: ['*'] })); - } - return { uri: this.uri }; - } -} diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts index 
5a66d346c2c5e..fedb975af4aed 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -3,8 +3,16 @@ import * as iam from '@aws-cdk/aws-iam'; import * as sfn from '@aws-cdk/aws-stepfunctions'; import { Construct, Duration, Lazy, Stack } from '@aws-cdk/core'; import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; -import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, - S3DataType, StoppingCondition, VpcConfig } from './sagemaker-task-base-types'; +import { + AlgorithmSpecification, + Channel, + InputMode, + OutputDataConfig, + ResourceConfig, + S3DataType, + StoppingCondition, + VpcConfig, +} from './sagemaker-task-base-types'; /** * Properties for creating an Amazon SageMaker training job @@ -12,7 +20,6 @@ import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceC * @experimental */ export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps { - /** * Training Job Name. */ @@ -24,7 +31,7 @@ export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps * * See https://docs.aws.amazon.com/fr_fr/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-createtrainingjob-perms * - * @default - a role with appropriate permissions will be created. + * @default - a role will be created. */ readonly role?: iam.IRole; @@ -40,7 +47,7 @@ export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps * * @default - No hyperparameters */ - readonly hyperparameters?: {[key: string]: any}; + readonly hyperparameters?: { [key: string]: any }; /** * Describes the various datasets (e.g. train, validation, test) and the Amazon S3 location where stored. 
@@ -52,7 +59,7 @@ export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps * * @default - No tags */ - readonly tags?: {[key: string]: string}; + readonly tags?: { [key: string]: string }; /** * Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training. @@ -87,7 +94,6 @@ export interface SageMakerCreateTrainingJobProps extends sfn.TaskStateBaseProps * @experimental */ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam.IGrantable, ec2.IConnectable { - private static readonly SUPPORTED_INTEGRATION_PATTERNS: sfn.IntegrationPattern[] = [ sfn.IntegrationPattern.REQUEST_RESPONSE, sfn.IntegrationPattern.RUN_JOB, @@ -148,20 +154,19 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam }; // check that either algorithm name or image is defined - if ((!props.algorithmSpecification.algorithmName) && (!props.algorithmSpecification.trainingImage)) { + if (!props.algorithmSpecification.algorithmName && !props.algorithmSpecification.trainingImage) { throw new Error('Must define either an algorithm name or training image URI in the algorithm specification'); } // set the input mode to 'File' if not defined - this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ? - ( props.algorithmSpecification ) : - ( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } ); + this.algorithmSpecification = props.algorithmSpecification.trainingInputMode + ? 
props.algorithmSpecification + : { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE }; // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined - this.inputDataConfig = props.inputDataConfig.map(config => { + this.inputDataConfig = props.inputDataConfig.map((config) => { if (!config.dataSource.s3DataSource.s3DataType) { - return Object.assign({}, config, { dataSource: { s3DataSource: - { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } }); + return Object.assign({}, config, { dataSource: { s3DataSource: { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } }); } else { return config; } @@ -170,8 +175,7 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam // add the security groups to the connections object if (props.vpcConfig) { this.vpc = props.vpcConfig.vpc; - this.subnets = (props.vpcConfig.subnets) ? - (this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds) : this.vpc.selectSubnets().subnetIds; + this.subnets = props.vpcConfig.subnets ? 
this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds : this.vpc.selectSubnets().subnetIds; } this.taskPolicies = this.makePolicyStatements(); @@ -207,87 +211,83 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam } public renderTask(): any { - return this.bind(); - } - - public bind(): any { return { Resource: integrationResourceArn('sagemaker', 'createTrainingJob', this.integrationPattern), Parameters: sfn.FieldUtils.renderObject(this.renderParameters()), }; } - private renderParameters(): {[key: string]: any} { + private renderParameters(): { [key: string]: any } { return { TrainingJobName: this.props.trainingJobName, RoleArn: this._role!.roleArn, - ...(this.renderAlgorithmSpecification(this.algorithmSpecification)), - ...(this.renderInputDataConfig(this.inputDataConfig)), - ...(this.renderOutputDataConfig(this.props.outputDataConfig)), - ...(this.renderResourceConfig(this.resourceConfig)), - ...(this.renderStoppingCondition(this.stoppingCondition)), - ...(this.renderHyperparameters(this.props.hyperparameters)), - ...(this.renderTags(this.props.tags)), - ...(this.renderVpcConfig(this.props.vpcConfig)), + ...this.renderAlgorithmSpecification(this.algorithmSpecification), + ...this.renderInputDataConfig(this.inputDataConfig), + ...this.renderOutputDataConfig(this.props.outputDataConfig), + ...this.renderResourceConfig(this.resourceConfig), + ...this.renderStoppingCondition(this.stoppingCondition), + ...this.renderHyperparameters(this.props.hyperparameters), + ...this.renderTags(this.props.tags), + ...this.renderVpcConfig(this.props.vpcConfig), }; } - private renderAlgorithmSpecification(spec: AlgorithmSpecification): {[key: string]: any} { + private renderAlgorithmSpecification(spec: AlgorithmSpecification): { [key: string]: any } { return { AlgorithmSpecification: { TrainingInputMode: spec.trainingInputMode, - ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage.bind(this).imageUri } : {}, - ...(spec.algorithmName) ? 
{ AlgorithmName: spec.algorithmName } : {}, - ...(spec.metricDefinitions) ? - { MetricDefinitions: spec.metricDefinitions - .map(metric => ({ Name: metric.name, Regex: metric.regex })) } : {}, + ...(spec.trainingImage ? { TrainingImage: spec.trainingImage.bind(this).imageUri } : {}), + ...(spec.algorithmName ? { AlgorithmName: spec.algorithmName } : {}), + ...(spec.metricDefinitions + ? { MetricDefinitions: spec.metricDefinitions.map((metric) => ({ Name: metric.name, Regex: metric.regex })) } + : {}), }, }; } - private renderInputDataConfig(config: Channel[]): {[key: string]: any} { + private renderInputDataConfig(config: Channel[]): { [key: string]: any } { return { - InputDataConfig: config.map(channel => ({ + InputDataConfig: config.map((channel) => ({ ChannelName: channel.channelName, DataSource: { S3DataSource: { S3Uri: channel.dataSource.s3DataSource.s3Location.bind(this, { forReading: true }).uri, S3DataType: channel.dataSource.s3DataSource.s3DataType, - ...(channel.dataSource.s3DataSource.s3DataDistributionType) ? - { S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType} : {}, - ...(channel.dataSource.s3DataSource.attributeNames) ? - { AtttributeNames: channel.dataSource.s3DataSource.attributeNames } : {}, + ...(channel.dataSource.s3DataSource.s3DataDistributionType + ? { S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType } + : {}), + ...(channel.dataSource.s3DataSource.attributeNames ? { AtttributeNames: channel.dataSource.s3DataSource.attributeNames } : {}), }, }, - ...(channel.compressionType) ? { CompressionType: channel.compressionType } : {}, - ...(channel.contentType) ? { ContentType: channel.contentType } : {}, - ...(channel.inputMode) ? { InputMode: channel.inputMode } : {}, - ...(channel.recordWrapperType) ? { RecordWrapperType: channel.recordWrapperType } : {}, + ...(channel.compressionType ? { CompressionType: channel.compressionType } : {}), + ...(channel.contentType ? 
{ ContentType: channel.contentType } : {}), + ...(channel.inputMode ? { InputMode: channel.inputMode } : {}), + ...(channel.recordWrapperType ? { RecordWrapperType: channel.recordWrapperType } : {}), })), }; } - private renderOutputDataConfig(config: OutputDataConfig): {[key: string]: any} { + private renderOutputDataConfig(config: OutputDataConfig): { [key: string]: any } { return { OutputDataConfig: { S3OutputPath: config.s3OutputLocation.bind(this, { forWriting: true }).uri, - ...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {}, + ...(config.encryptionKey ? { KmsKeyId: config.encryptionKey.keyArn } : {}), }, }; } - private renderResourceConfig(config: ResourceConfig): {[key: string]: any} { + private renderResourceConfig(config: ResourceConfig): { [key: string]: any } { return { ResourceConfig: { InstanceCount: config.instanceCount, InstanceType: 'ml.' + config.instanceType, VolumeSizeInGB: config.volumeSizeInGB, - ...(config.volumeEncryptionKey) ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {}, + ...(config.volumeEncryptionKey ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {}), }, }; } - private renderStoppingCondition(config: StoppingCondition): {[key: string]: any} { + private renderStoppingCondition(config: StoppingCondition): { [key: string]: any } { return { StoppingCondition: { MaxRuntimeInSeconds: config.maxRuntime && config.maxRuntime.toSeconds(), @@ -295,56 +295,62 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam }; } - private renderHyperparameters(params: {[key: string]: any} | undefined): {[key: string]: any} { - return (params) ? { HyperParameters: params } : {}; + private renderHyperparameters(params: { [key: string]: any } | undefined): { [key: string]: any } { + return params ? { HyperParameters: params } : {}; } - private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { - return (tags) ? 
{ Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; + private renderTags(tags: { [key: string]: any } | undefined): { [key: string]: any } { + return tags ? { Tags: Object.keys(tags).map((key) => ({ Key: key, Value: tags[key] })) } : {}; } - private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} { - return (config) ? { VpcConfig: { - SecurityGroupIds: Lazy.listValue({ produce: () => (this.securityGroups.map(sg => (sg.securityGroupId))) }), - Subnets: this.subnets, - }} : {}; + private renderVpcConfig(config: VpcConfig | undefined): { [key: string]: any } { + return config + ? { + VpcConfig: { + SecurityGroupIds: Lazy.listValue({ produce: () => this.securityGroups.map((sg) => sg.securityGroupId) }), + Subnets: this.subnets, + }, + } + : {}; } private makePolicyStatements(): iam.PolicyStatement[] { // set the sagemaker role or create new one - this._grantPrincipal = this._role = this.props.role || new iam.Role(this, 'SagemakerRole', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - inlinePolicies: { - CreateTrainingJob: new iam.PolicyDocument({ - statements: [ - new iam.PolicyStatement({ - actions: [ - 'cloudwatch:PutMetricData', - 'logs:CreateLogStream', - 'logs:PutLogEvents', - 'logs:CreateLogGroup', - 'logs:DescribeLogStreams', - 'ecr:GetAuthorizationToken', - ...this.props.vpcConfig - ? 
[ - 'ec2:CreateNetworkInterface', - 'ec2:CreateNetworkInterfacePermission', - 'ec2:DeleteNetworkInterface', - 'ec2:DeleteNetworkInterfacePermission', - 'ec2:DescribeNetworkInterfaces', - 'ec2:DescribeVpcs', - 'ec2:DescribeDhcpOptions', - 'ec2:DescribeSubnets', - 'ec2:DescribeSecurityGroups', - ] - : [], - ], - resources: ['*'], // Those permissions cannot be resource-scoped - }), - ], - }), - }, - }); + this._grantPrincipal = this._role = + this.props.role || + new iam.Role(this, 'SagemakerRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + inlinePolicies: { + CreateTrainingJob: new iam.PolicyDocument({ + statements: [ + new iam.PolicyStatement({ + actions: [ + 'cloudwatch:PutMetricData', + 'logs:CreateLogStream', + 'logs:PutLogEvents', + 'logs:CreateLogGroup', + 'logs:DescribeLogStreams', + 'ecr:GetAuthorizationToken', + ...(this.props.vpcConfig + ? [ + 'ec2:CreateNetworkInterface', + 'ec2:CreateNetworkInterfacePermission', + 'ec2:DeleteNetworkInterface', + 'ec2:DeleteNetworkInterfacePermission', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DescribeVpcs', + 'ec2:DescribeDhcpOptions', + 'ec2:DescribeSubnets', + 'ec2:DescribeSecurityGroups', + ] + : []), + ], + resources: ['*'], // Those permissions cannot be resource-scoped + }), + ], + }), + }, + }); if (this.props.outputDataConfig.encryptionKey) { this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role); @@ -392,14 +398,18 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam ]; if (this.integrationPattern === sfn.IntegrationPattern.RUN_JOB) { - policyStatements.push(new iam.PolicyStatement({ - actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], - resources: [stack.formatArn({ - service: 'events', - resource: 'rule', - resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule', - })], - })); + policyStatements.push( + new iam.PolicyStatement({ + actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], + 
resources: [ + stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule', + }), + ], + }), + ); } return policyStatements; diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts index 1db442e348f75..5acc51f249e60 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts @@ -11,7 +11,7 @@ import { Construct, Duration } from '@aws-cdk/core'; * Task to train a machine learning model using Amazon SageMaker * @experimental */ -export interface ISageMakerTask extends sfn.IStepFunctionsTask, iam.IGrantable {} +export interface ISageMakerTask extends iam.IGrantable {} /** * Specify the training algorithm and algorithm-specific metadata From 114a2d3078dd8f7ea5d51c969ac12970ac2955ec Mon Sep 17 00:00:00 2001 From: Shiv Lakshminarayan Date: Wed, 3 Jun 2020 00:05:57 -0700 Subject: [PATCH 3/9] add to index, update usage of kms.Key to take in IKey --- .../aws-stepfunctions-tasks/lib/index.ts | 2 + .../lib/sagemaker/create-training-job.ts | 2 +- .../lib/sagemaker/create-transform-job.ts | 108 +++++++++--------- .../sagemaker/sagemaker-task-base-types.ts | 2 +- .../sagemaker/create-transform-job.test.ts | 2 +- 5 files changed, 58 insertions(+), 58 deletions(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts index 7b45086a4e48e..4a1ec5c4555f1 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts @@ -13,6 +13,8 @@ export * from './ecs/run-ecs-fargate-task'; export * from './sagemaker/sagemaker-task-base-types'; export * from './sagemaker/sagemaker-train-task'; export * from 
'./sagemaker/sagemaker-transform-task'; +export * from './sagemaker/create-training-job'; +export * from './sagemaker/create-transform-job'; export * from './start-execution'; export * from './evaluate-expression'; export * from './emr/emr-create-cluster'; diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts index fedb975af4aed..bbe8a19846b34 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -210,7 +210,7 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam this.securityGroups.push(securityGroup); } - public renderTask(): any { + protected renderTask(): any { return { Resource: integrationResourceArn('sagemaker', 'createTrainingJob', this.integrationPattern), Parameters: sfn.FieldUtils.renderObject(this.renderParameters()), diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts index eb80f742dcb42..7ea5103c2bb32 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts @@ -1,7 +1,7 @@ import * as ec2 from '@aws-cdk/aws-ec2'; import * as iam from '@aws-cdk/aws-iam'; import * as sfn from '@aws-cdk/aws-stepfunctions'; -import { Construct, Stack } from '@aws-cdk/core'; +import { Construct, Size, Stack } from '@aws-cdk/core'; import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; @@ -11,7 +11,6 @@ import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformRe * @experimental */ 
export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps { - /** * Training Job Name. */ @@ -36,7 +35,7 @@ export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps * * @default - No environment variables */ - readonly environment?: {[key: string]: string}; + readonly environment?: { [key: string]: string }; /** * Maximum number of parallel requests that can be sent to each instance in a transform job. @@ -51,7 +50,7 @@ export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps * * @default 6 */ - readonly maxPayloadInMB?: number; + readonly maxPayload?: Size; /** * Name of the model that you want to use for the transform job. @@ -63,7 +62,7 @@ export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps * * @default - No tags */ - readonly tags?: {[key: string]: string}; + readonly tags?: { [key: string]: string }; /** * Dataset to be transformed and the Amazon S3 location where it is stored. @@ -89,7 +88,6 @@ export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps * @experimental */ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { - private static readonly SUPPORTED_INTEGRATION_PATTERNS: sfn.IntegrationPattern[] = [ sfn.IntegrationPattern.REQUEST_RESPONSE, sfn.IntegrationPattern.RUN_JOB, @@ -121,15 +119,11 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { } // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined - this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? (props.transformInput) : - Object.assign({}, props.transformInput, - { transformDataSource: - { s3DataSource: - { ...props.transformInput.transformDataSource.s3DataSource, - s3DataType: S3DataType.S3_PREFIX, - }, - }, - }); + this.transformInput = props.transformInput.transformDataSource.s3DataSource.s3DataType + ? 
props.transformInput + : Object.assign({}, props.transformInput, { + transformDataSource: { s3DataSource: { ...props.transformInput.transformDataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } }, + }); // set the default value for the transform resources this.transformResources = props.transformResources || { @@ -140,7 +134,7 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { this.taskPolicies = this.makePolicyStatements(); } - public renderTask(): any { + protected renderTask(): any { return { Resource: integrationResourceArn('sagemaker', 'createTransformJob', this.integrationPattern), Parameters: sfn.FieldUtils.renderObject(this.renderParameters()), @@ -159,64 +153,64 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { return this._role; } - private renderParameters(): {[key: string]: any} { + private renderParameters(): { [key: string]: any } { return { - ...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {}, - ...(this.renderEnvironment(this.props.environment)), - ...(this.props.maxConcurrentTransforms) ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}, - ...(this.props.maxPayloadInMB) ? { MaxPayloadInMB: this.props.maxPayloadInMB } : {}, + ...(this.props.batchStrategy ? { BatchStrategy: this.props.batchStrategy } : {}), + ...this.renderEnvironment(this.props.environment), + ...(this.props.maxConcurrentTransforms ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}), + ...(this.props.maxPayload ? 
{ MaxPayloadInMB: this.props.maxPayload.toMebibytes() } : {}), ModelName: this.props.modelName, - ...(this.renderTags(this.props.tags)), - ...(this.renderTransformInput(this.transformInput)), + ...this.renderTags(this.props.tags), + ...this.renderTransformInput(this.transformInput), TransformJobName: this.props.transformJobName, - ...(this.renderTransformOutput(this.props.transformOutput)), - ...(this.renderTransformResources(this.transformResources)), + ...this.renderTransformOutput(this.props.transformOutput), + ...this.renderTransformResources(this.transformResources), }; } - private renderTransformInput(input: TransformInput): {[key: string]: any} { + private renderTransformInput(input: TransformInput): { [key: string]: any } { return { TransformInput: { - ...(input.compressionType) ? { CompressionType: input.compressionType } : {}, - ...(input.contentType) ? { ContentType: input.contentType } : {}, + ...(input.compressionType ? { CompressionType: input.compressionType } : {}), + ...(input.contentType ? { ContentType: input.contentType } : {}), DataSource: { S3DataSource: { S3Uri: input.transformDataSource.s3DataSource.s3Uri, S3DataType: input.transformDataSource.s3DataSource.s3DataType, }, }, - ...(input.splitType) ? { SplitType: input.splitType } : {}, + ...(input.splitType ? { SplitType: input.splitType } : {}), }, }; } - private renderTransformOutput(output: TransformOutput): {[key: string]: any} { + private renderTransformOutput(output: TransformOutput): { [key: string]: any } { return { TransformOutput: { S3OutputPath: output.s3OutputPath, - ...(output.encryptionKey) ? { KmsKeyId: output.encryptionKey.keyArn } : {}, - ...(output.accept) ? { Accept: output.accept } : {}, - ...(output.assembleWith) ? { AssembleWith: output.assembleWith } : {}, + ...(output.encryptionKey ? { KmsKeyId: output.encryptionKey.keyArn } : {}), + ...(output.accept ? { Accept: output.accept } : {}), + ...(output.assembleWith ? 
{ AssembleWith: output.assembleWith } : {}), }, }; } - private renderTransformResources(resources: TransformResources): {[key: string]: any} { + private renderTransformResources(resources: TransformResources): { [key: string]: any } { return { TransformResources: { InstanceCount: resources.instanceCount, InstanceType: 'ml.' + resources.instanceType, - ...(resources.volumeKmsKeyId) ? { VolumeKmsKeyId: resources.volumeKmsKeyId.keyArn } : {}, + ...(resources.volumeKmsKeyId ? { VolumeKmsKeyId: resources.volumeKmsKeyId.keyArn } : {}), }, }; } - private renderEnvironment(environment: {[key: string]: any} | undefined): {[key: string]: any} { - return (environment) ? { Environment: environment } : {}; + private renderEnvironment(environment: { [key: string]: any } | undefined): { [key: string]: any } { + return environment ? { Environment: environment } : {}; } - private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { - return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; + private renderTags(tags: { [key: string]: any } | undefined): { [key: string]: any } { + return tags ? 
{ Tags: Object.keys(tags).map((key) => ({ Key: key, Value: tags[key] })) } : {}; } private makePolicyStatements(): iam.PolicyStatement[] { @@ -226,9 +220,7 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { if (this._role === undefined) { this._role = new iam.Role(this, 'SagemakerTransformRole', { assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - managedPolicies: [ - iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), - ], + managedPolicies: [iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')], }); } @@ -236,11 +228,13 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { const policyStatements = [ new iam.PolicyStatement({ actions: ['sagemaker:CreateTransformJob', 'sagemaker:DescribeTransformJob', 'sagemaker:StopTransformJob'], - resources: [stack.formatArn({ - service: 'sagemaker', - resource: 'transform-job', - resourceName: '*', - })], + resources: [ + stack.formatArn({ + service: 'sagemaker', + resource: 'transform-job', + resourceName: '*', + }), + ], }), new iam.PolicyStatement({ actions: ['sagemaker:ListTags'], @@ -256,14 +250,18 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { ]; if (this.integrationPattern === sfn.IntegrationPattern.RUN_JOB) { - policyStatements.push(new iam.PolicyStatement({ - actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], - resources: [stack.formatArn({ - service: 'events', - resource: 'rule', - resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule', - }) ], - })); + policyStatements.push( + new iam.PolicyStatement({ + actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], + resources: [ + stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule', + }), + ], + }), + ); } return policyStatements; diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts 
b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts index 5acc51f249e60..05451fbf9d8cf 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts @@ -622,7 +622,7 @@ export interface TransformResources { * * @default - None */ - readonly volumeKmsKeyId?: kms.Key; + readonly volumeKmsKeyId?: kms.IKey; } /** diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts index 21d700d6e06dc..923118adccddb 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts @@ -129,7 +129,7 @@ test('create complex transform job', () => { SOMEVAR: 'myvalue', }, maxConcurrentTransforms: 3, - maxPayloadInMB: 100, + maxPayload: cdk.Size.mebibytes(100), }); // THEN From 7ae75e4a59f211da19e506aecd71193b406d5ba9 Mon Sep 17 00:00:00 2001 From: Shiv Lakshminarayan Date: Thu, 4 Jun 2020 16:32:08 -0700 Subject: [PATCH 4/9] remove older experimental implementations --- .../aws-stepfunctions-tasks/lib/index.ts | 4 +- ...maker-task-base-types.ts => base-types.ts} | 0 .../lib/sagemaker/create-training-job.ts | 41 +- .../lib/sagemaker/create-transform-job.ts | 6 +- .../lib/sagemaker/sagemaker-train-task.ts | 409 ------------------ .../lib/sagemaker/sagemaker-transform-task.ts | 278 ------------ .../sagemaker/integ.sagemaker.expected.json | 402 ----------------- .../test/sagemaker/integ.sagemaker.ts | 34 -- .../sagemaker/sagemaker-training-job.test.ts | 399 ----------------- .../sagemaker/sagemaker-transform-job.test.ts | 242 ----------- 10 files changed, 20 insertions(+), 1795 deletions(-) rename packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/{sagemaker-task-base-types.ts 
=> base-types.ts} (100%) delete mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-train-task.ts delete mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-transform-task.ts delete mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.expected.json delete mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.ts delete mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-training-job.test.ts delete mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-transform-job.test.ts diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts index 4a1ec5c4555f1..5dd2bb1b038d9 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts @@ -10,9 +10,7 @@ export * from './sqs/send-to-queue'; export * from './sqs/send-message'; export * from './ecs/run-ecs-ec2-task'; export * from './ecs/run-ecs-fargate-task'; -export * from './sagemaker/sagemaker-task-base-types'; -export * from './sagemaker/sagemaker-train-task'; -export * from './sagemaker/sagemaker-transform-task'; +export * from './sagemaker/base-types'; export * from './sagemaker/create-training-job'; export * from './sagemaker/create-transform-job'; export * from './start-execution'; diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts similarity index 100% rename from packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-task-base-types.ts rename to packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts index 
bbe8a19846b34..e9064950aeb79 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -3,16 +3,7 @@ import * as iam from '@aws-cdk/aws-iam'; import * as sfn from '@aws-cdk/aws-stepfunctions'; import { Construct, Duration, Lazy, Stack } from '@aws-cdk/core'; import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; -import { - AlgorithmSpecification, - Channel, - InputMode, - OutputDataConfig, - ResourceConfig, - S3DataType, - StoppingCondition, - VpcConfig, -} from './sagemaker-task-base-types'; +import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, S3DataType, StoppingCondition, VpcConfig } from './base-types'; /** * Properties for creating an Amazon SageMaker training job @@ -306,11 +297,11 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam private renderVpcConfig(config: VpcConfig | undefined): { [key: string]: any } { return config ? { - VpcConfig: { - SecurityGroupIds: Lazy.listValue({ produce: () => this.securityGroups.map((sg) => sg.securityGroupId) }), - Subnets: this.subnets, - }, - } + VpcConfig: { + SecurityGroupIds: Lazy.listValue({ produce: () => this.securityGroups.map((sg) => sg.securityGroupId) }), + Subnets: this.subnets, + }, + } : {}; } @@ -333,16 +324,16 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam 'ecr:GetAuthorizationToken', ...(this.props.vpcConfig ? 
[ - 'ec2:CreateNetworkInterface', - 'ec2:CreateNetworkInterfacePermission', - 'ec2:DeleteNetworkInterface', - 'ec2:DeleteNetworkInterfacePermission', - 'ec2:DescribeNetworkInterfaces', - 'ec2:DescribeVpcs', - 'ec2:DescribeDhcpOptions', - 'ec2:DescribeSubnets', - 'ec2:DescribeSecurityGroups', - ] + 'ec2:CreateNetworkInterface', + 'ec2:CreateNetworkInterfacePermission', + 'ec2:DeleteNetworkInterface', + 'ec2:DeleteNetworkInterfacePermission', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DescribeVpcs', + 'ec2:DescribeDhcpOptions', + 'ec2:DescribeSubnets', + 'ec2:DescribeSecurityGroups', + ] : []), ], resources: ['*'], // Those permissions cannot be resource-scoped diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts index 7ea5103c2bb32..73e4965e45bb3 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts @@ -3,7 +3,7 @@ import * as iam from '@aws-cdk/aws-iam'; import * as sfn from '@aws-cdk/aws-stepfunctions'; import { Construct, Size, Stack } from '@aws-cdk/core'; import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; -import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; +import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './base-types'; /** * Properties for creating an Amazon SageMaker training job task @@ -122,8 +122,8 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { this.transformInput = props.transformInput.transformDataSource.s3DataSource.s3DataType ? 
props.transformInput : Object.assign({}, props.transformInput, { - transformDataSource: { s3DataSource: { ...props.transformInput.transformDataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } }, - }); + transformDataSource: { s3DataSource: { ...props.transformInput.transformDataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } }, + }); // set the default value for the transform resources this.transformResources = props.transformResources || { diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-train-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-train-task.ts deleted file mode 100644 index 758e8a065dc8c..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-train-task.ts +++ /dev/null @@ -1,409 +0,0 @@ -import * as ec2 from '@aws-cdk/aws-ec2'; -import * as iam from '@aws-cdk/aws-iam'; -import * as sfn from '@aws-cdk/aws-stepfunctions'; -import { Duration, Lazy, Stack } from '@aws-cdk/core'; -import { getResourceArn } from '../resource-arn-suffix'; -import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, - S3DataType, StoppingCondition, VpcConfig } from './sagemaker-task-base-types'; - -/** - * Properties for creating an Amazon SageMaker training job - * - * @experimental - */ -export interface SagemakerTrainTaskProps { - - /** - * Training Job Name. - */ - readonly trainingJobName: string; - - /** - * Role for the Training Job. The role must be granted all necessary permissions for the SageMaker training job to - * be able to operate. - * - * See https://docs.aws.amazon.com/fr_fr/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-createtrainingjob-perms - * - * @default - a role with appropriate permissions will be created. - */ - readonly role?: iam.IRole; - - /** - * The service integration pattern indicates different ways to call SageMaker APIs. - * - * The valid value is either FIRE_AND_FORGET or SYNC. 
- * - * @default FIRE_AND_FORGET - */ - readonly integrationPattern?: sfn.ServiceIntegrationPattern; - - /** - * Identifies the training algorithm to use. - */ - readonly algorithmSpecification: AlgorithmSpecification; - - /** - * Algorithm-specific parameters that influence the quality of the model. Set hyperparameters before you start the learning process. - * For a list of hyperparameters provided by Amazon SageMaker - * @see https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html - * - * @default - No hyperparameters - */ - readonly hyperparameters?: {[key: string]: any}; - - /** - * Describes the various datasets (e.g. train, validation, test) and the Amazon S3 location where stored. - */ - readonly inputDataConfig: Channel[]; - - /** - * Tags to be applied to the train job. - * - * @default - No tags - */ - readonly tags?: {[key: string]: string}; - - /** - * Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training. - */ - readonly outputDataConfig: OutputDataConfig; - - /** - * Specifies the resources, ML compute instances, and ML storage volumes to deploy for model training. - * - * @default - 1 instance of EC2 `M4.XLarge` with `10GB` volume - */ - readonly resourceConfig?: ResourceConfig; - - /** - * Sets a time limit for training. - * - * @default - max runtime of 1 hour - */ - readonly stoppingCondition?: StoppingCondition; - - /** - * Specifies the VPC that you want your training job to connect to. - * - * @default - No VPC - */ - readonly vpcConfig?: VpcConfig; -} - -/** - * Class representing the SageMaker Create Training Job task. - * - * @experimental - */ -export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn.IStepFunctionsTask { - - /** - * Allows specify security group connections for instances of this fleet. 
- */ - public readonly connections: ec2.Connections = new ec2.Connections(); - - /** - * The Algorithm Specification - */ - private readonly algorithmSpecification: AlgorithmSpecification; - - /** - * The Input Data Config. - */ - private readonly inputDataConfig: Channel[]; - - /** - * The resource config for the task. - */ - private readonly resourceConfig: ResourceConfig; - - /** - * The resource config for the task. - */ - private readonly stoppingCondition: StoppingCondition; - - private readonly vpc?: ec2.IVpc; - private securityGroup?: ec2.ISecurityGroup; - private readonly securityGroups: ec2.ISecurityGroup[] = []; - private readonly subnets?: string[]; - private readonly integrationPattern: sfn.ServiceIntegrationPattern; - private _role?: iam.IRole; - private _grantPrincipal?: iam.IPrincipal; - - constructor(private readonly props: SagemakerTrainTaskProps) { - this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET; - - const supportedPatterns = [ - sfn.ServiceIntegrationPattern.FIRE_AND_FORGET, - sfn.ServiceIntegrationPattern.SYNC, - ]; - - if (!supportedPatterns.includes(this.integrationPattern)) { - throw new Error(`Invalid Service Integration Pattern: ${this.integrationPattern} is not supported to call SageMaker.`); - } - - // set the default resource config if not defined. 
- this.resourceConfig = props.resourceConfig || { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE), - volumeSizeInGB: 10, - }; - - // set the stopping condition if not defined - this.stoppingCondition = props.stoppingCondition || { - maxRuntime: Duration.hours(1), - }; - - // check that either algorithm name or image is defined - if ((!props.algorithmSpecification.algorithmName) && (!props.algorithmSpecification.trainingImage)) { - throw new Error('Must define either an algorithm name or training image URI in the algorithm specification'); - } - - // set the input mode to 'File' if not defined - this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ? - ( props.algorithmSpecification ) : - ( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } ); - - // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined - this.inputDataConfig = props.inputDataConfig.map(config => { - if (!config.dataSource.s3DataSource.s3DataType) { - return Object.assign({}, config, { dataSource: { s3DataSource: - { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } }); - } else { - return config; - } - }); - - // add the security groups to the connections object - if (props.vpcConfig) { - this.vpc = props.vpcConfig.vpc; - this.subnets = (props.vpcConfig.subnets) ? - (this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds) : this.vpc.selectSubnets().subnetIds; - } - } - - /** - * The execution role for the Sagemaker training job. - * - * Only available after task has been added to a state machine. 
- */ - public get role(): iam.IRole { - if (this._role === undefined) { - throw new Error('role not available yet--use the object in a Task first'); - } - return this._role; - } - - public get grantPrincipal(): iam.IPrincipal { - if (this._grantPrincipal === undefined) { - throw new Error('Principal not available yet--use the object in a Task first'); - } - return this._grantPrincipal; - } - - /** - * Add the security group to all instances via the launch configuration - * security groups array. - * - * @param securityGroup: The security group to add - */ - public addSecurityGroup(securityGroup: ec2.ISecurityGroup): void { - this.securityGroups.push(securityGroup); - } - - public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { - // set the sagemaker role or create new one - this._grantPrincipal = this._role = this.props.role || new iam.Role(task, 'SagemakerRole', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - inlinePolicies: { - CreateTrainingJob: new iam.PolicyDocument({ - statements: [ - new iam.PolicyStatement({ - actions: [ - 'cloudwatch:PutMetricData', - 'logs:CreateLogStream', - 'logs:PutLogEvents', - 'logs:CreateLogGroup', - 'logs:DescribeLogStreams', - 'ecr:GetAuthorizationToken', - ...this.props.vpcConfig - ? 
[ - 'ec2:CreateNetworkInterface', - 'ec2:CreateNetworkInterfacePermission', - 'ec2:DeleteNetworkInterface', - 'ec2:DeleteNetworkInterfacePermission', - 'ec2:DescribeNetworkInterfaces', - 'ec2:DescribeVpcs', - 'ec2:DescribeDhcpOptions', - 'ec2:DescribeSubnets', - 'ec2:DescribeSecurityGroups', - ] - : [], - ], - resources: ['*'], // Those permissions cannot be resource-scoped - }), - ], - }), - }, - }); - - if (this.props.outputDataConfig.encryptionKey) { - this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role); - } - - if (this.props.resourceConfig && this.props.resourceConfig.volumeEncryptionKey) { - this.props.resourceConfig.volumeEncryptionKey.grant(this._role, 'kms:CreateGrant'); - } - - // create a security group if not defined - if (this.vpc && this.securityGroup === undefined) { - this.securityGroup = new ec2.SecurityGroup(task, 'TrainJobSecurityGroup', { - vpc: this.vpc, - }); - this.connections.addSecurityGroup(this.securityGroup); - this.securityGroups.push(this.securityGroup); - } - - return { - resourceArn: getResourceArn('sagemaker', 'createTrainingJob', this.integrationPattern), - parameters: this.renderParameters(), - policyStatements: this.makePolicyStatements(task), - }; - } - - private renderParameters(): {[key: string]: any} { - return { - TrainingJobName: this.props.trainingJobName, - RoleArn: this._role!.roleArn, - ...(this.renderAlgorithmSpecification(this.algorithmSpecification)), - ...(this.renderInputDataConfig(this.inputDataConfig)), - ...(this.renderOutputDataConfig(this.props.outputDataConfig)), - ...(this.renderResourceConfig(this.resourceConfig)), - ...(this.renderStoppingCondition(this.stoppingCondition)), - ...(this.renderHyperparameters(this.props.hyperparameters)), - ...(this.renderTags(this.props.tags)), - ...(this.renderVpcConfig(this.props.vpcConfig)), - }; - } - - private renderAlgorithmSpecification(spec: AlgorithmSpecification): {[key: string]: any} { - return { - AlgorithmSpecification: { - TrainingInputMode: 
spec.trainingInputMode, - ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage.bind(this).imageUri } : {}, - ...(spec.algorithmName) ? { AlgorithmName: spec.algorithmName } : {}, - ...(spec.metricDefinitions) ? - { MetricDefinitions: spec.metricDefinitions - .map(metric => ({ Name: metric.name, Regex: metric.regex })) } : {}, - }, - }; - } - - private renderInputDataConfig(config: Channel[]): {[key: string]: any} { - return { - InputDataConfig: config.map(channel => ({ - ChannelName: channel.channelName, - DataSource: { - S3DataSource: { - S3Uri: channel.dataSource.s3DataSource.s3Location.bind(this, { forReading: true }).uri, - S3DataType: channel.dataSource.s3DataSource.s3DataType, - ...(channel.dataSource.s3DataSource.s3DataDistributionType) ? - { S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType} : {}, - ...(channel.dataSource.s3DataSource.attributeNames) ? - { AtttributeNames: channel.dataSource.s3DataSource.attributeNames } : {}, - }, - }, - ...(channel.compressionType) ? { CompressionType: channel.compressionType } : {}, - ...(channel.contentType) ? { ContentType: channel.contentType } : {}, - ...(channel.inputMode) ? { InputMode: channel.inputMode } : {}, - ...(channel.recordWrapperType) ? { RecordWrapperType: channel.recordWrapperType } : {}, - })), - }; - } - - private renderOutputDataConfig(config: OutputDataConfig): {[key: string]: any} { - return { - OutputDataConfig: { - S3OutputPath: config.s3OutputLocation.bind(this, { forWriting: true }).uri, - ...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {}, - }, - }; - } - - private renderResourceConfig(config: ResourceConfig): {[key: string]: any} { - return { - ResourceConfig: { - InstanceCount: config.instanceCount, - InstanceType: 'ml.' + config.instanceType, - VolumeSizeInGB: config.volumeSizeInGB, - ...(config.volumeEncryptionKey) ? 
{ VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {}, - }, - }; - } - - private renderStoppingCondition(config: StoppingCondition): {[key: string]: any} { - return { - StoppingCondition: { - MaxRuntimeInSeconds: config.maxRuntime && config.maxRuntime.toSeconds(), - }, - }; - } - - private renderHyperparameters(params: {[key: string]: any} | undefined): {[key: string]: any} { - return (params) ? { HyperParameters: params } : {}; - } - - private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { - return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; - } - - private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} { - return (config) ? { VpcConfig: { - SecurityGroupIds: Lazy.listValue({ produce: () => (this.securityGroups.map(sg => (sg.securityGroupId))) }), - Subnets: this.subnets, - }} : {}; - } - - private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { - const stack = Stack.of(task); - - // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html - const policyStatements = [ - new iam.PolicyStatement({ - actions: ['sagemaker:CreateTrainingJob', 'sagemaker:DescribeTrainingJob', 'sagemaker:StopTrainingJob'], - resources: [ - stack.formatArn({ - service: 'sagemaker', - resource: 'training-job', - // If the job name comes from input, we cannot target the policy to a particular ARN prefix reliably... - resourceName: sfn.Data.isJsonPathString(this.props.trainingJobName) ? 
'*' : `${this.props.trainingJobName}*`, - }), - ], - }), - new iam.PolicyStatement({ - actions: ['sagemaker:ListTags'], - resources: ['*'], - }), - new iam.PolicyStatement({ - actions: ['iam:PassRole'], - resources: [this._role!.roleArn], - conditions: { - StringEquals: { 'iam:PassedToService': 'sagemaker.amazonaws.com' }, - }, - }), - ]; - - if (this.integrationPattern === sfn.ServiceIntegrationPattern.SYNC) { - policyStatements.push(new iam.PolicyStatement({ - actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], - resources: [stack.formatArn({ - service: 'events', - resource: 'rule', - resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule', - })], - })); - } - - return policyStatements; - } -} diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-transform-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-transform-task.ts deleted file mode 100644 index 5d4449d052a17..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/sagemaker-transform-task.ts +++ /dev/null @@ -1,278 +0,0 @@ -import * as ec2 from '@aws-cdk/aws-ec2'; -import * as iam from '@aws-cdk/aws-iam'; -import * as sfn from '@aws-cdk/aws-stepfunctions'; -import { Stack } from '@aws-cdk/core'; -import { getResourceArn } from '../resource-arn-suffix'; -import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; - -/** - * Properties for creating an Amazon SageMaker training job task - * - * @experimental - */ -export interface SagemakerTransformProps { - - /** - * Training Job Name. - */ - readonly transformJobName: string; - - /** - * Role for the Training Job. - * - * @default - A role is created with `AmazonSageMakerFullAccess` managed policy - */ - readonly role?: iam.IRole; - - /** - * The service integration pattern indicates different ways to call SageMaker APIs. - * - * The valid value is either FIRE_AND_FORGET or SYNC. 
- * - * @default FIRE_AND_FORGET - */ - readonly integrationPattern?: sfn.ServiceIntegrationPattern; - - /** - * Number of records to include in a mini-batch for an HTTP inference request. - * - * @default - No batch strategy - */ - readonly batchStrategy?: BatchStrategy; - - /** - * Environment variables to set in the Docker container. - * - * @default - No environment variables - */ - readonly environment?: {[key: string]: string}; - - /** - * Maximum number of parallel requests that can be sent to each instance in a transform job. - * - * @default - Amazon SageMaker checks the optional execution-parameters to determine the settings for your chosen algorithm. - * If the execution-parameters endpoint is not enabled, the default value is 1. - */ - readonly maxConcurrentTransforms?: number; - - /** - * Maximum allowed size of the payload, in MB. - * - * @default 6 - */ - readonly maxPayloadInMB?: number; - - /** - * Name of the model that you want to use for the transform job. - */ - readonly modelName: string; - - /** - * Tags to be applied to the train job. - * - * @default - No tags - */ - readonly tags?: {[key: string]: string}; - - /** - * Dataset to be transformed and the Amazon S3 location where it is stored. - */ - readonly transformInput: TransformInput; - - /** - * S3 location where you want Amazon SageMaker to save the results from the transform job. - */ - readonly transformOutput: TransformOutput; - - /** - * ML compute instances for the transform job. - * - * @default - 1 instance of type M4.XLarge - */ - readonly transformResources?: TransformResources; -} - -/** - * Class representing the SageMaker Create Training Job task. - * - * @experimental - */ -export class SagemakerTransformTask implements sfn.IStepFunctionsTask { - - /** - * Dataset to be transformed and the Amazon S3 location where it is stored. - */ - private readonly transformInput: TransformInput; - - /** - * ML compute instances for the transform job. 
- */ - private readonly transformResources: TransformResources; - private readonly integrationPattern: sfn.ServiceIntegrationPattern; - private _role?: iam.IRole; - - constructor(private readonly props: SagemakerTransformProps) { - this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET; - - const supportedPatterns = [ - sfn.ServiceIntegrationPattern.FIRE_AND_FORGET, - sfn.ServiceIntegrationPattern.SYNC, - ]; - - if (!supportedPatterns.includes(this.integrationPattern)) { - throw new Error(`Invalid Service Integration Pattern: ${this.integrationPattern} is not supported to call SageMaker.`); - } - - // set the sagemaker role or create new one - if (props.role) { - this._role = props.role; - } - - // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined - this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? (props.transformInput) : - Object.assign({}, props.transformInput, - { transformDataSource: - { s3DataSource: - { ...props.transformInput.transformDataSource.s3DataSource, - s3DataType: S3DataType.S3_PREFIX, - }, - }, - }); - - // set the default value for the transform resources - this.transformResources = props.transformResources || { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE), - }; - } - - public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { - // create new role if doesn't exist - if (this._role === undefined) { - this._role = new iam.Role(task, 'SagemakerTransformRole', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - managedPolicies: [ - iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), - ], - }); - } - - return { - resourceArn: getResourceArn('sagemaker', 'createTransformJob', this.integrationPattern), - parameters: this.renderParameters(), - policyStatements: this.makePolicyStatements(task), - }; - } - - /** - * The execution role for 
the Sagemaker training job. - * - * Only available after task has been added to a state machine. - */ - public get role(): iam.IRole { - if (this._role === undefined) { - throw new Error('role not available yet--use the object in a Task first'); - } - return this._role; - } - - private renderParameters(): {[key: string]: any} { - return { - ...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {}, - ...(this.renderEnvironment(this.props.environment)), - ...(this.props.maxConcurrentTransforms) ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}, - ...(this.props.maxPayloadInMB) ? { MaxPayloadInMB: this.props.maxPayloadInMB } : {}, - ModelName: this.props.modelName, - ...(this.renderTags(this.props.tags)), - ...(this.renderTransformInput(this.transformInput)), - TransformJobName: this.props.transformJobName, - ...(this.renderTransformOutput(this.props.transformOutput)), - ...(this.renderTransformResources(this.transformResources)), - }; - } - - private renderTransformInput(input: TransformInput): {[key: string]: any} { - return { - TransformInput: { - ...(input.compressionType) ? { CompressionType: input.compressionType } : {}, - ...(input.contentType) ? { ContentType: input.contentType } : {}, - DataSource: { - S3DataSource: { - S3Uri: input.transformDataSource.s3DataSource.s3Uri, - S3DataType: input.transformDataSource.s3DataSource.s3DataType, - }, - }, - ...(input.splitType) ? { SplitType: input.splitType } : {}, - }, - }; - } - - private renderTransformOutput(output: TransformOutput): {[key: string]: any} { - return { - TransformOutput: { - S3OutputPath: output.s3OutputPath, - ...(output.encryptionKey) ? { KmsKeyId: output.encryptionKey.keyArn } : {}, - ...(output.accept) ? { Accept: output.accept } : {}, - ...(output.assembleWith) ? 
{ AssembleWith: output.assembleWith } : {}, - }, - }; - } - - private renderTransformResources(resources: TransformResources): {[key: string]: any} { - return { - TransformResources: { - InstanceCount: resources.instanceCount, - InstanceType: 'ml.' + resources.instanceType, - ...(resources.volumeKmsKeyId) ? { VolumeKmsKeyId: resources.volumeKmsKeyId.keyArn } : {}, - }, - }; - } - - private renderEnvironment(environment: {[key: string]: any} | undefined): {[key: string]: any} { - return (environment) ? { Environment: environment } : {}; - } - - private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { - return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; - } - - private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { - const stack = Stack.of(task); - - // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html - const policyStatements = [ - new iam.PolicyStatement({ - actions: ['sagemaker:CreateTransformJob', 'sagemaker:DescribeTransformJob', 'sagemaker:StopTransformJob'], - resources: [stack.formatArn({ - service: 'sagemaker', - resource: 'transform-job', - resourceName: '*', - })], - }), - new iam.PolicyStatement({ - actions: ['sagemaker:ListTags'], - resources: ['*'], - }), - new iam.PolicyStatement({ - actions: ['iam:PassRole'], - resources: [this.role.roleArn], - conditions: { - StringEquals: { 'iam:PassedToService': 'sagemaker.amazonaws.com' }, - }, - }), - ]; - - if (this.integrationPattern === sfn.ServiceIntegrationPattern.SYNC) { - policyStatements.push(new iam.PolicyStatement({ - actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'], - resources: [stack.formatArn({ - service: 'events', - resource: 'rule', - resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule', - }) ], - })); - } - - return policyStatements; - } -} diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.expected.json 
b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.expected.json deleted file mode 100644 index 52aeac4dc5de3..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.expected.json +++ /dev/null @@ -1,402 +0,0 @@ -{ - "Resources": { - "EncryptionKey1B843E66": { - "Type": "AWS::KMS::Key", - "Properties": { - "KeyPolicy": { - "Statement": [ - { - "Action": [ - "kms:Create*", - "kms:Describe*", - "kms:Enable*", - "kms:List*", - "kms:Put*", - "kms:Update*", - "kms:Revoke*", - "kms:Disable*", - "kms:Get*", - "kms:Delete*", - "kms:ScheduleKeyDeletion", - "kms:CancelKeyDeletion", - "kms:GenerateDataKey", - "kms:TagResource", - "kms:UntagResource" - ], - "Effect": "Allow", - "Principal": { - "AWS": { - "Fn::Join": [ - "", - [ - "arn:", - { - "Ref": "AWS::Partition" - }, - ":iam::", - { - "Ref": "AWS::AccountId" - }, - ":root" - ] - ] - } - }, - "Resource": "*" - }, - { - "Action": [ - "kms:Decrypt", - "kms:DescribeKey" - ], - "Effect": "Allow", - "Principal": { - "AWS": { - "Fn::GetAtt": [ - "TrainTaskSagemakerRole0A9B1CDD", - "Arn" - ] - } - }, - "Resource": "*" - }, - { - "Action": [ - "kms:Encrypt", - "kms:ReEncrypt*", - "kms:GenerateDataKey*" - ], - "Effect": "Allow", - "Principal": { - "AWS": { - "Fn::GetAtt": [ - "TrainTaskSagemakerRole0A9B1CDD", - "Arn" - ] - } - }, - "Resource": "*" - } - ], - "Version": "2012-10-17" - } - }, - "UpdateReplacePolicy": "Delete", - "DeletionPolicy": "Delete" - }, - "TrainingData3FDB6D34": { - "Type": "AWS::S3::Bucket", - "Properties": { - "BucketEncryption": { - "ServerSideEncryptionConfiguration": [ - { - "ServerSideEncryptionByDefault": { - "KMSMasterKeyID": { - "Fn::GetAtt": [ - "EncryptionKey1B843E66", - "Arn" - ] - }, - "SSEAlgorithm": "aws:kms" - } - } - ] - } - }, - "UpdateReplacePolicy": "Delete", - "DeletionPolicy": "Delete" - }, - "TrainTaskSagemakerRole0A9B1CDD": { - "Type": "AWS::IAM::Role", - "Properties": { - "AssumeRolePolicyDocument": { - "Statement": [ - { - 
"Action": "sts:AssumeRole", - "Effect": "Allow", - "Principal": { - "Service": "sagemaker.amazonaws.com" - } - } - ], - "Version": "2012-10-17" - }, - "Policies": [ - { - "PolicyDocument": { - "Statement": [ - { - "Action": [ - "cloudwatch:PutMetricData", - "logs:CreateLogStream", - "logs:PutLogEvents", - "logs:CreateLogGroup", - "logs:DescribeLogStreams", - "ecr:GetAuthorizationToken" - ], - "Effect": "Allow", - "Resource": "*" - } - ], - "Version": "2012-10-17" - }, - "PolicyName": "CreateTrainingJob" - } - ] - } - }, - "TrainTaskSagemakerRoleDefaultPolicyA28F72FA": { - "Type": "AWS::IAM::Policy", - "Properties": { - "PolicyDocument": { - "Statement": [ - { - "Action": [ - "s3:GetObject*", - "s3:GetBucket*", - "s3:List*" - ], - "Effect": "Allow", - "Resource": [ - { - "Fn::GetAtt": [ - "TrainingData3FDB6D34", - "Arn" - ] - }, - { - "Fn::Join": [ - "", - [ - { - "Fn::GetAtt": [ - "TrainingData3FDB6D34", - "Arn" - ] - }, - "/data/*" - ] - ] - } - ] - }, - { - "Action": [ - "kms:Decrypt", - "kms:DescribeKey" - ], - "Effect": "Allow", - "Resource": { - "Fn::GetAtt": [ - "EncryptionKey1B843E66", - "Arn" - ] - } - }, - { - "Action": [ - "s3:DeleteObject*", - "s3:PutObject*", - "s3:Abort*" - ], - "Effect": "Allow", - "Resource": [ - { - "Fn::GetAtt": [ - "TrainingData3FDB6D34", - "Arn" - ] - }, - { - "Fn::Join": [ - "", - [ - { - "Fn::GetAtt": [ - "TrainingData3FDB6D34", - "Arn" - ] - }, - "/result/*" - ] - ] - } - ] - }, - { - "Action": [ - "kms:Encrypt", - "kms:ReEncrypt*", - "kms:GenerateDataKey*" - ], - "Effect": "Allow", - "Resource": { - "Fn::GetAtt": [ - "EncryptionKey1B843E66", - "Arn" - ] - } - } - ], - "Version": "2012-10-17" - }, - "PolicyName": "TrainTaskSagemakerRoleDefaultPolicyA28F72FA", - "Roles": [ - { - "Ref": "TrainTaskSagemakerRole0A9B1CDD" - } - ] - } - }, - "StateMachineRoleB840431D": { - "Type": "AWS::IAM::Role", - "Properties": { - "AssumeRolePolicyDocument": { - "Statement": [ - { - "Action": "sts:AssumeRole", - "Effect": "Allow", - "Principal": 
{ - "Service": { - "Fn::Join": [ - "", - [ - "states.", - { - "Ref": "AWS::Region" - }, - ".amazonaws.com" - ] - ] - } - } - } - ], - "Version": "2012-10-17" - } - } - }, - "StateMachineRoleDefaultPolicyDF1E6607": { - "Type": "AWS::IAM::Policy", - "Properties": { - "PolicyDocument": { - "Statement": [ - { - "Action": [ - "sagemaker:CreateTrainingJob", - "sagemaker:DescribeTrainingJob", - "sagemaker:StopTrainingJob" - ], - "Effect": "Allow", - "Resource": { - "Fn::Join": [ - "", - [ - "arn:", - { - "Ref": "AWS::Partition" - }, - ":sagemaker:", - { - "Ref": "AWS::Region" - }, - ":", - { - "Ref": "AWS::AccountId" - }, - ":training-job/MyTrainingJob*" - ] - ] - } - }, - { - "Action": "sagemaker:ListTags", - "Effect": "Allow", - "Resource": "*" - }, - { - "Action": "iam:PassRole", - "Condition": { - "StringEquals": { - "iam:PassedToService": "sagemaker.amazonaws.com" - } - }, - "Effect": "Allow", - "Resource": { - "Fn::GetAtt": [ - "TrainTaskSagemakerRole0A9B1CDD", - "Arn" - ] - } - } - ], - "Version": "2012-10-17" - }, - "PolicyName": "StateMachineRoleDefaultPolicyDF1E6607", - "Roles": [ - { - "Ref": "StateMachineRoleB840431D" - } - ] - } - }, - "StateMachine2E01A3A5": { - "Type": "AWS::StepFunctions::StateMachine", - "Properties": { - "DefinitionString": { - "Fn::Join": [ - "", - [ - "{\"StartAt\":\"TrainTask\",\"States\":{\"TrainTask\":{\"End\":true,\"Parameters\":{\"TrainingJobName\":\"MyTrainingJob\",\"RoleArn\":\"", - { - "Fn::GetAtt": [ - "TrainTaskSagemakerRole0A9B1CDD", - "Arn" - ] - }, - "\",\"AlgorithmSpecification\":{\"TrainingInputMode\":\"File\",\"AlgorithmName\":\"GRADIENT_ASCENT\"},\"InputDataConfig\":[{\"ChannelName\":\"InputData\",\"DataSource\":{\"S3DataSource\":{\"S3Uri\":\"https://s3.", - { - "Ref": "AWS::Region" - }, - ".", - { - "Ref": "AWS::URLSuffix" - }, - "/", - { - "Ref": "TrainingData3FDB6D34" - }, - "/data/\",\"S3DataType\":\"S3Prefix\"}}}],\"OutputDataConfig\":{\"S3OutputPath\":\"https://s3.", - { - "Ref": "AWS::Region" - }, - ".", - { - 
"Ref": "AWS::URLSuffix" - }, - "/", - { - "Ref": "TrainingData3FDB6D34" - }, - "/result/\"},\"ResourceConfig\":{\"InstanceCount\":1,\"InstanceType\":\"ml.m4.xlarge\",\"VolumeSizeInGB\":10},\"StoppingCondition\":{\"MaxRuntimeInSeconds\":3600}},\"Type\":\"Task\",\"Resource\":\"arn:", - { - "Ref": "AWS::Partition" - }, - ":states:::sagemaker:createTrainingJob\"}}}" - ] - ] - }, - "RoleArn": { - "Fn::GetAtt": [ - "StateMachineRoleB840431D", - "Arn" - ] - } - }, - "DependsOn": [ - "StateMachineRoleDefaultPolicyDF1E6607", - "StateMachineRoleB840431D" - ] - } - } -} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.ts deleted file mode 100644 index 661f1f1bbd006..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.sagemaker.ts +++ /dev/null @@ -1,34 +0,0 @@ -import { Key } from '@aws-cdk/aws-kms'; -import { Bucket, BucketEncryption } from '@aws-cdk/aws-s3'; -import { StateMachine, Task } from '@aws-cdk/aws-stepfunctions'; -import { App, RemovalPolicy, Stack } from '@aws-cdk/core'; -import { S3Location, SagemakerTrainTask } from '../../lib'; - -const app = new App(); -const stack = new Stack(app, 'integ-stepfunctions-sagemaker'); - -const encryptionKey = new Key(stack, 'EncryptionKey', { - removalPolicy: RemovalPolicy.DESTROY, -}); -const trainingData = new Bucket(stack, 'TrainingData', { - encryption: BucketEncryption.KMS, - encryptionKey, - removalPolicy: RemovalPolicy.DESTROY, -}); - -new StateMachine(stack, 'StateMachine', { - definition: new Task(stack, 'TrainTask', { - task: new SagemakerTrainTask({ - algorithmSpecification: { - algorithmName: 'GRADIENT_ASCENT', - }, - inputDataConfig: [{ channelName: 'InputData', dataSource: { - s3DataSource: { - s3Location: S3Location.fromBucket(trainingData, 'data/'), - }, - } }], - outputDataConfig: { s3OutputLocation: S3Location.fromBucket(trainingData, 
'result/') }, - trainingJobName: 'MyTrainingJob', - }), - }), -}); diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-training-job.test.ts deleted file mode 100644 index 58b7d314b535d..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-training-job.test.ts +++ /dev/null @@ -1,399 +0,0 @@ -import '@aws-cdk/assert/jest'; -import * as ec2 from '@aws-cdk/aws-ec2'; -import * as iam from '@aws-cdk/aws-iam'; -import * as kms from '@aws-cdk/aws-kms'; -import * as s3 from '@aws-cdk/aws-s3'; -import * as sfn from '@aws-cdk/aws-stepfunctions'; -import * as cdk from '@aws-cdk/core'; -import * as tasks from '../../lib'; - -let stack: cdk.Stack; - -beforeEach(() => { - // GIVEN - stack = new cdk.Stack(); -}); - -test('create basic training job', () => { - // WHEN - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ - trainingJobName: 'MyTrainJob', - algorithmSpecification: { - algorithmName: 'BlazingText', - }, - inputDataConfig: [ - { - channelName: 'train', - dataSource: { - s3DataSource: { - s3Location: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucket', 'mybucket'), 'mytrainpath'), - }, - }, - }, - ], - outputDataConfig: { - s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath'), - }, - })}); - - // THEN - expect(stack.resolve(task.toStateJson())).toEqual({ - Type: 'Task', - Resource: { - 'Fn::Join': [ - '', - [ - 'arn:', - { - Ref: 'AWS::Partition', - }, - ':states:::sagemaker:createTrainingJob', - ], - ], - }, - End: true, - Parameters: { - AlgorithmSpecification: { - AlgorithmName: 'BlazingText', - TrainingInputMode: 'File', - }, - InputDataConfig: [ - { - ChannelName: 'train', - DataSource: { - S3DataSource: { - S3DataType: 'S3Prefix', - S3Uri: { - 'Fn::Join': ['', ['https://s3.', { Ref: 
'AWS::Region'}, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/mytrainpath']], - }, - }, - }, - }, - ], - OutputDataConfig: { - S3OutputPath: { - 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region' }, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/myoutputpath']], - }, - }, - ResourceConfig: { - InstanceCount: 1, - InstanceType: 'ml.m4.xlarge', - VolumeSizeInGB: 10, - }, - RoleArn: { 'Fn::GetAtt': [ 'TrainSagemakerSagemakerRole89E8C593', 'Arn' ] }, - StoppingCondition: { - MaxRuntimeInSeconds: 3600, - }, - TrainingJobName: 'MyTrainJob', - }, - }); -}); - -test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration pattern', () => { - expect(() => { - new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ - integrationPattern: sfn.ServiceIntegrationPattern.WAIT_FOR_TASK_TOKEN, - trainingJobName: 'MyTrainJob', - algorithmSpecification: { - algorithmName: 'BlazingText', - }, - inputDataConfig: [ - { - channelName: 'train', - dataSource: { - s3DataSource: { - s3Location: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucket', 'mybucket'), 'mytrainpath'), - }, - }, - }, - ], - outputDataConfig: { - s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath'), - }, - })}); - }).toThrow(/Invalid Service Integration Pattern: WAIT_FOR_TASK_TOKEN is not supported to call SageMaker./i); -}); - -test('create complex training job', () => { - // WHEN - const kmsKey = new kms.Key(stack, 'Key'); - const vpc = new ec2.Vpc(stack, 'VPC'); - const securityGroup = new ec2.SecurityGroup(stack, 'SecurityGroup', { vpc, description: 'My SG' }); - securityGroup.addIngressRule(ec2.Peer.anyIpv4(), ec2.Port.tcp(22), 'allow ssh access from the world'); - - const role = new iam.Role(stack, 'Role', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - managedPolicies: [ - iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), - ], - }); - - const 
trainTask = new tasks.SagemakerTrainTask({ - trainingJobName: 'MyTrainJob', - integrationPattern: sfn.ServiceIntegrationPattern.SYNC, - role, - algorithmSpecification: { - algorithmName: 'BlazingText', - trainingInputMode: tasks.InputMode.FILE, - metricDefinitions: [ - { - name: 'mymetric', regex: 'regex_pattern', - }, - ], - }, - hyperparameters: { - lr: '0.1', - }, - inputDataConfig: [ - { - channelName: 'train', - contentType: 'image/jpeg', - compressionType: tasks.CompressionType.NONE, - recordWrapperType: tasks.RecordWrapperType.RECORD_IO, - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3_PREFIX, - s3Location: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucketA', 'mybucket'), 'mytrainpath'), - }, - }, - }, - { - channelName: 'test', - contentType: 'image/jpeg', - compressionType: tasks.CompressionType.GZIP, - recordWrapperType: tasks.RecordWrapperType.RECORD_IO, - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3_PREFIX, - s3Location: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'InputBucketB', 'mybucket'), 'mytestpath'), - }, - }, - }, - ], - outputDataConfig: { - s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'OutputBucket', 'mybucket'), 'myoutputpath'), - encryptionKey: kmsKey, - }, - resourceConfig: { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeSizeInGB: 50, - volumeEncryptionKey: kmsKey, - }, - stoppingCondition: { - maxRuntime: cdk.Duration.hours(1), - }, - tags: { - Project: 'MyProject', - }, - vpcConfig: { - vpc, - }, - }); - trainTask.addSecurityGroup(securityGroup); - const task = new sfn.Task(stack, 'TrainSagemaker', { task: trainTask }); - - // THEN - expect(stack.resolve(task.toStateJson())).toEqual({ - Type: 'Task', - Resource: { - 'Fn::Join': [ - '', - [ - 'arn:', - { - Ref: 'AWS::Partition', - }, - ':states:::sagemaker:createTrainingJob.sync', - ], - ], - }, - End: true, - 
Parameters: { - TrainingJobName: 'MyTrainJob', - RoleArn: { 'Fn::GetAtt': [ 'Role1ABCC5F0', 'Arn' ] }, - AlgorithmSpecification: { - TrainingInputMode: 'File', - AlgorithmName: 'BlazingText', - MetricDefinitions: [ - { Name: 'mymetric', Regex: 'regex_pattern' }, - ], - }, - HyperParameters: { - lr: '0.1', - }, - InputDataConfig: [ - { - ChannelName: 'train', - CompressionType: 'None', - RecordWrapperType: 'RecordIO', - ContentType: 'image/jpeg', - DataSource: { - S3DataSource: { - S3DataType: 'S3Prefix', - S3Uri: { - 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region'}, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/mytrainpath']], - }, - }, - }, - }, - { - ChannelName: 'test', - CompressionType: 'Gzip', - RecordWrapperType: 'RecordIO', - ContentType: 'image/jpeg', - DataSource: { - S3DataSource: { - S3DataType: 'S3Prefix', - S3Uri: { - 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region'}, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/mytestpath']], - }, - }, - }, - }, - ], - OutputDataConfig: { - S3OutputPath: { - 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region' }, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/myoutputpath']], - }, - KmsKeyId: { 'Fn::GetAtt': [ 'Key961B73FD', 'Arn' ] }, - }, - ResourceConfig: { - InstanceCount: 1, - InstanceType: 'ml.p3.2xlarge', - VolumeSizeInGB: 50, - VolumeKmsKeyId: { 'Fn::GetAtt': [ 'Key961B73FD', 'Arn' ] }, - }, - StoppingCondition: { - MaxRuntimeInSeconds: 3600, - }, - Tags: [ - { Key: 'Project', Value: 'MyProject' }, - ], - VpcConfig: { - SecurityGroupIds: [ - { 'Fn::GetAtt': [ 'SecurityGroupDD263621', 'GroupId' ] }, - { 'Fn::GetAtt': [ 'TrainSagemakerTrainJobSecurityGroup7C858EB9', 'GroupId' ] }, - ], - Subnets: [ - { Ref: 'VPCPrivateSubnet1Subnet8BCA10E0' }, - { Ref: 'VPCPrivateSubnet2SubnetCFCDAA7A' }, - ], - }, - }, - }); -}); - -test('pass param to training job', () => { - // WHEN - const role = new iam.Role(stack, 'Role', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - managedPolicies: [ - 
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), - ], - }); - - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ - trainingJobName: sfn.Data.stringAt('$.JobName'), - role, - algorithmSpecification: { - algorithmName: 'BlazingText', - trainingInputMode: tasks.InputMode.FILE, - }, - inputDataConfig: [ - { - channelName: 'train', - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3_PREFIX, - s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), - }, - }, - }, - ], - outputDataConfig: { - s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath'), - }, - resourceConfig: { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeSizeInGB: 50, - }, - stoppingCondition: { - maxRuntime: cdk.Duration.hours(1), - }, - })}); - - // THEN - expect(stack.resolve(task.toStateJson())).toEqual({ - Type: 'Task', - Resource: { - 'Fn::Join': [ - '', - [ - 'arn:', - { - Ref: 'AWS::Partition', - }, - ':states:::sagemaker:createTrainingJob', - ], - ], - }, - End: true, - Parameters: { - 'TrainingJobName.$': '$.JobName', - 'RoleArn': { 'Fn::GetAtt': [ 'Role1ABCC5F0', 'Arn' ] }, - 'AlgorithmSpecification': { - TrainingInputMode: 'File', - AlgorithmName: 'BlazingText', - }, - 'InputDataConfig': [ - { - ChannelName: 'train', - DataSource: { - S3DataSource: { - 'S3DataType': 'S3Prefix', - 'S3Uri.$': '$.S3Bucket', - }, - }, - }, - ], - 'OutputDataConfig': { - S3OutputPath: { - 'Fn::Join': ['', ['https://s3.', { Ref: 'AWS::Region' }, '.', { Ref: 'AWS::URLSuffix' }, '/mybucket/myoutputpath']], - }, - }, - 'ResourceConfig': { - InstanceCount: 1, - InstanceType: 'ml.p3.2xlarge', - VolumeSizeInGB: 50, - }, - 'StoppingCondition': { - MaxRuntimeInSeconds: 3600, - }, - }, - }); -}); - -test('Cannot create a SageMaker train task with both algorithm name and image name missing', () => { - - expect(() => 
new tasks.SagemakerTrainTask({ - trainingJobName: 'myTrainJob', - algorithmSpecification: {}, - inputDataConfig: [ - { - channelName: 'train', - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3_PREFIX, - s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), - }, - }, - }, - ], - outputDataConfig: { - s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath/'), - }, - })) - .toThrowError(/Must define either an algorithm name or training image URI in the algorithm specification/); -}); diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-transform-job.test.ts deleted file mode 100644 index c08a28bb0c973..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/sagemaker-transform-job.test.ts +++ /dev/null @@ -1,242 +0,0 @@ -import '@aws-cdk/assert/jest'; -import * as ec2 from '@aws-cdk/aws-ec2'; -import * as iam from '@aws-cdk/aws-iam'; -import * as kms from '@aws-cdk/aws-kms'; -import * as sfn from '@aws-cdk/aws-stepfunctions'; -import * as cdk from '@aws-cdk/core'; -import * as tasks from '../../lib'; - -let stack: cdk.Stack; -let role: iam.Role; - -beforeEach(() => { - // GIVEN - stack = new cdk.Stack(); - role = new iam.Role(stack, 'Role', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - managedPolicies: [ - iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess'), - ], - }); -}); - -test('create basic transform job', () => { - // WHEN - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ - transformJobName: 'MyTransformJob', - modelName: 'MyModelName', - transformInput: { - transformDataSource: { - s3DataSource: { - s3Uri: 's3://inputbucket/prefix', - }, - }, - }, - transformOutput: { - s3OutputPath: 's3://outputbucket/prefix', - }, - }) }); - - // THEN - 
expect(stack.resolve(task.toStateJson())).toEqual({ - Type: 'Task', - Resource: { - 'Fn::Join': [ - '', - [ - 'arn:', - { - Ref: 'AWS::Partition', - }, - ':states:::sagemaker:createTransformJob', - ], - ], - }, - End: true, - Parameters: { - TransformJobName: 'MyTransformJob', - ModelName: 'MyModelName', - TransformInput: { - DataSource: { - S3DataSource: { - S3Uri: 's3://inputbucket/prefix', - S3DataType: 'S3Prefix', - }, - }, - }, - TransformOutput: { - S3OutputPath: 's3://outputbucket/prefix', - }, - TransformResources: { - InstanceCount: 1, - InstanceType: 'ml.m4.xlarge', - }, - }, - }); -}); - -test('Task throws if WAIT_FOR_TASK_TOKEN is supplied as service integration pattern', () => { - expect(() => { - new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ - integrationPattern: sfn.ServiceIntegrationPattern.WAIT_FOR_TASK_TOKEN, - transformJobName: 'MyTransformJob', - modelName: 'MyModelName', - transformInput: { - transformDataSource: { - s3DataSource: { - s3Uri: 's3://inputbucket/prefix', - }, - }, - }, - transformOutput: { - s3OutputPath: 's3://outputbucket/prefix', - }, - }) }); - }).toThrow(/Invalid Service Integration Pattern: WAIT_FOR_TASK_TOKEN is not supported to call SageMaker./i); -}); - -test('create complex transform job', () => { - // WHEN - const kmsKey = new kms.Key(stack, 'Key'); - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ - transformJobName: 'MyTransformJob', - modelName: 'MyModelName', - integrationPattern: sfn.ServiceIntegrationPattern.SYNC, - role, - transformInput: { - transformDataSource: { - s3DataSource: { - s3Uri: 's3://inputbucket/prefix', - s3DataType: tasks.S3DataType.S3_PREFIX, - }, - }, - }, - transformOutput: { - s3OutputPath: 's3://outputbucket/prefix', - encryptionKey: kmsKey, - }, - transformResources: { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeKmsKeyId: kmsKey, - }, - tags: { - 
Project: 'MyProject', - }, - batchStrategy: tasks.BatchStrategy.MULTI_RECORD, - environment: { - SOMEVAR: 'myvalue', - }, - maxConcurrentTransforms: 3, - maxPayloadInMB: 100, - }) }); - - // THEN - expect(stack.resolve(task.toStateJson())).toEqual({ - Type: 'Task', - Resource: { - 'Fn::Join': [ - '', - [ - 'arn:', - { - Ref: 'AWS::Partition', - }, - ':states:::sagemaker:createTransformJob.sync', - ], - ], - }, - End: true, - Parameters: { - TransformJobName: 'MyTransformJob', - ModelName: 'MyModelName', - TransformInput: { - DataSource: { - S3DataSource: { - S3Uri: 's3://inputbucket/prefix', - S3DataType: 'S3Prefix', - }, - }, - }, - TransformOutput: { - S3OutputPath: 's3://outputbucket/prefix', - KmsKeyId: { 'Fn::GetAtt': [ 'Key961B73FD', 'Arn' ] }, - }, - TransformResources: { - InstanceCount: 1, - InstanceType: 'ml.p3.2xlarge', - VolumeKmsKeyId: { 'Fn::GetAtt': [ 'Key961B73FD', 'Arn' ] }, - }, - Tags: [ - { Key: 'Project', Value: 'MyProject' }, - ], - MaxConcurrentTransforms: 3, - MaxPayloadInMB: 100, - Environment: { - SOMEVAR: 'myvalue', - }, - BatchStrategy: 'MultiRecord', - }, - }); -}); - -test('pass param to transform job', () => { - // WHEN - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ - transformJobName: sfn.Data.stringAt('$.TransformJobName'), - modelName: sfn.Data.stringAt('$.ModelName'), - role, - transformInput: { - transformDataSource: { - s3DataSource: { - s3Uri: 's3://inputbucket/prefix', - s3DataType: tasks.S3DataType.S3_PREFIX, - }, - }, - }, - transformOutput: { - s3OutputPath: 's3://outputbucket/prefix', - }, - transformResources: { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - }, - }) }); - - // THEN - expect(stack.resolve(task.toStateJson())).toEqual({ - Type: 'Task', - Resource: { - 'Fn::Join': [ - '', - [ - 'arn:', - { - Ref: 'AWS::Partition', - }, - ':states:::sagemaker:createTransformJob', - ], - ], - }, - End: true, - 
Parameters: { - 'TransformJobName.$': '$.TransformJobName', - 'ModelName.$': '$.ModelName', - 'TransformInput': { - DataSource: { - S3DataSource: { - S3Uri: 's3://inputbucket/prefix', - S3DataType: 'S3Prefix', - }, - }, - }, - 'TransformOutput': { - S3OutputPath: 's3://outputbucket/prefix', - }, - 'TransformResources': { - InstanceCount: 1, - InstanceType: 'ml.p3.2xlarge', - }, - }, - }); -}); From 7c09a93f728c396f898179db07ecacf49aa74372 Mon Sep 17 00:00:00 2001 From: Shiv Lakshminarayan Date: Thu, 4 Jun 2020 23:24:12 -0700 Subject: [PATCH 5/9] update volumeSizeInGB to just be volumeSize --- .../aws-stepfunctions-tasks/README.md | 94 +++++++++---------- .../lib/sagemaker/base-types.ts | 4 +- .../lib/sagemaker/create-training-job.ts | 36 +++---- .../lib/sagemaker/create-transform-job.ts | 4 +- .../sagemaker/create-training-job.test.ts | 4 +- 5 files changed, 68 insertions(+), 74 deletions(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/README.md b/packages/@aws-cdk/aws-stepfunctions-tasks/README.md index c8482f9e57f09..55e1f97989a64 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/README.md +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/README.md @@ -617,37 +617,33 @@ Step Functions supports [AWS SageMaker](https://docs.aws.amazon.com/step-functio You can call the [`CreateTrainingJob`](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html) API from a `Task` state. 
```ts -new sfn.Task(stack, 'TrainSagemaker', { - task: new tasks.SagemakerTrainTask({ - trainingJobName: sfn.Data.stringAt('$.JobName'), - role, - algorithmSpecification: { - algorithmName: 'BlazingText', - trainingInputMode: tasks.InputMode.FILE, - }, - inputDataConfig: [ - { - channelName: 'train', - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3_PREFIX, - s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), - }, - }, +new tasks.SageMakerCreateTrainingJob(this, 'TrainSagemaker', { + trainingJobName: sfn.Data.stringAt('$.JobName'), + role, + algorithmSpecification: { + algorithmName: 'BlazingText', + trainingInputMode: tasks.InputMode.FILE, + }, + inputDataConfig: [{ + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3_PREFIX, + s3Location: tasks.S3Location.fromJsonExpression('$.S3Bucket'), }, - ], - outputDataConfig: { - s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath'), - }, - resourceConfig: { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeSizeInGB: 50, }, - stoppingCondition: { - maxRuntime: cdk.Duration.hours(1), - }, - }), + }], + outputDataConfig: { + s3OutputLocation: tasks.S3Location.fromBucket(s3.Bucket.fromBucketName(stack, 'Bucket', 'mybucket'), 'myoutputpath'), + }, + resourceConfig: { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), + volumeSize: cdk.Size.gibibytes(50), + }, + stoppingCondition: { + maxRuntime: cdk.Duration.hours(1), + }, }); ``` @@ -656,29 +652,27 @@ new sfn.Task(stack, 'TrainSagemaker', { You can call the [`CreateTransformJob`](https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTransformJob.html) API from a `Task` state.
```ts -const transformJob = new tasks.SagemakerTransformTask( - transformJobName: "MyTransformJob", - modelName: "MyModelName", - role, - transformInput: { - transformDataSource: { - s3DataSource: { - s3Uri: 's3://inputbucket/train', - s3DataType: S3DataType.S3Prefix, - } - } - }, - transformOutput: { - s3OutputPath: 's3://outputbucket/TransformJobOutputPath', - }, - transformResources: { - instanceCount: 1, - instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), +new sfn.SagemakerTransformTask(this, 'Batch Inference', { + transformJobName: 'MyTransformJob', + modelName: 'MyModelName', + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/train', + s3DataType: S3DataType.S3Prefix, + } + } + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/TransformJobOutputPath', + }, + transformResources: { + instanceCount: 1, + instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), + } }); -const task = new sfn.Task(this, 'Batch Inference', { - task: transformJob -}); ``` ## SNS diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts index 05451fbf9d8cf..7f19e5eeaf120 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts @@ -5,7 +5,7 @@ import * as iam from '@aws-cdk/aws-iam'; import * as kms from '@aws-cdk/aws-kms'; import * as s3 from '@aws-cdk/aws-s3'; import * as sfn from '@aws-cdk/aws-stepfunctions'; -import { Construct, Duration } from '@aws-cdk/core'; +import { Construct, Duration, Size } from '@aws-cdk/core'; /** * Task to train a machine learning model using Amazon SageMaker @@ -230,7 +230,7 @@ export interface ResourceConfig { * * @default 10 GB EBS volume. 
*/ - readonly volumeSizeInGB: number; + readonly volumeSize: Size; } /** diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts index e9064950aeb79..f9c1309f59ce8 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -1,7 +1,7 @@ import * as ec2 from '@aws-cdk/aws-ec2'; import * as iam from '@aws-cdk/aws-iam'; import * as sfn from '@aws-cdk/aws-stepfunctions'; -import { Construct, Duration, Lazy, Stack } from '@aws-cdk/core'; +import { Construct, Duration, Lazy, Size, Stack } from '@aws-cdk/core'; import { integrationResourceArn, validatePatternSupported } from '../private/task-utils'; import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, S3DataType, StoppingCondition, VpcConfig } from './base-types'; @@ -136,7 +136,7 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam this.resourceConfig = props.resourceConfig || { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE), - volumeSizeInGB: 10, + volumeSize: Size.gibibytes(10), }; // set the stopping condition if not defined @@ -272,7 +272,7 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam ResourceConfig: { InstanceCount: config.instanceCount, InstanceType: 'ml.' + config.instanceType, - VolumeSizeInGB: config.volumeSizeInGB, + VolumeSizeInGB: config.volumeSize.toGibibytes(), ...(config.volumeEncryptionKey ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {}), }, }; @@ -297,11 +297,11 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam private renderVpcConfig(config: VpcConfig | undefined): { [key: string]: any } { return config ? 
{ - VpcConfig: { - SecurityGroupIds: Lazy.listValue({ produce: () => this.securityGroups.map((sg) => sg.securityGroupId) }), - Subnets: this.subnets, - }, - } + VpcConfig: { + SecurityGroupIds: Lazy.listValue({ produce: () => this.securityGroups.map((sg) => sg.securityGroupId) }), + Subnets: this.subnets, + }, + } : {}; } @@ -324,16 +324,16 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam 'ecr:GetAuthorizationToken', ...(this.props.vpcConfig ? [ - 'ec2:CreateNetworkInterface', - 'ec2:CreateNetworkInterfacePermission', - 'ec2:DeleteNetworkInterface', - 'ec2:DeleteNetworkInterfacePermission', - 'ec2:DescribeNetworkInterfaces', - 'ec2:DescribeVpcs', - 'ec2:DescribeDhcpOptions', - 'ec2:DescribeSubnets', - 'ec2:DescribeSecurityGroups', - ] + 'ec2:CreateNetworkInterface', + 'ec2:CreateNetworkInterfacePermission', + 'ec2:DeleteNetworkInterface', + 'ec2:DeleteNetworkInterfacePermission', + 'ec2:DescribeNetworkInterfaces', + 'ec2:DescribeVpcs', + 'ec2:DescribeDhcpOptions', + 'ec2:DescribeSubnets', + 'ec2:DescribeSecurityGroups', + ] : []), ], resources: ['*'], // Those permissions cannot be resource-scoped diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts index 73e4965e45bb3..816fb2ec45e52 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts @@ -122,8 +122,8 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { this.transformInput = props.transformInput.transformDataSource.s3DataSource.s3DataType ? 
props.transformInput : Object.assign({}, props.transformInput, { - transformDataSource: { s3DataSource: { ...props.transformInput.transformDataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } }, - }); + transformDataSource: { s3DataSource: { ...props.transformInput.transformDataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } }, + }); // set the default value for the transform resources this.transformResources = props.transformResources || { diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts index a186a4e4917fb..4f02f9ac048a1 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-training-job.test.ts @@ -178,7 +178,7 @@ test('create complex training job', () => { resourceConfig: { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeSizeInGB: 50, + volumeSize: cdk.Size.gibibytes(50), volumeEncryptionKey: kmsKey, }, stoppingCondition: { @@ -317,7 +317,7 @@ test('pass param to training job', () => { resourceConfig: { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeSizeInGB: 50, + volumeSize: cdk.Size.gibibytes(50), }, stoppingCondition: { maxRuntime: cdk.Duration.hours(1), From 3a0fae3968a2a8b2bc6769858eec627687bc3eac Mon Sep 17 00:00:00 2001 From: Shiv Lakshminarayan Date: Fri, 5 Jun 2020 00:10:28 -0700 Subject: [PATCH 6/9] update integ test to include some rough verification steps --- .../integ.create-training-job.expected.json | 25 +++++++++++------ .../sagemaker/integ.create-training-job.ts | 28 ++++++++++++++++--- 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json 
b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json index b8af9b15f61bf..cf95e9f59a16e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.expected.json @@ -304,7 +304,7 @@ { "Ref": "AWS::AccountId" }, - ":training-job/MyTrainingJob*" + ":training-job/mytrainingjob*" ] ] } @@ -343,6 +343,12 @@ "StateMachine2E01A3A5": { "Type": "AWS::StepFunctions::StateMachine", "Properties": { + "RoleArn": { + "Fn::GetAtt": [ + "StateMachineRoleB840431D", + "Arn" + ] + }, "DefinitionString": { "Fn::Join": [ "", @@ -351,14 +357,14 @@ { "Ref": "AWS::Partition" }, - ":states:::sagemaker:createTrainingJob\",\"Parameters\":{\"TrainingJobName\":\"MyTrainingJob\",\"RoleArn\":\"", + ":states:::sagemaker:createTrainingJob\",\"Parameters\":{\"TrainingJobName\":\"mytrainingjob\",\"RoleArn\":\"", { "Fn::GetAtt": [ "TrainTaskSagemakerRole0A9B1CDD", "Arn" ] }, - "\",\"AlgorithmSpecification\":{\"TrainingInputMode\":\"File\",\"AlgorithmName\":\"GRADIENT_ASCENT\"},\"InputDataConfig\":[{\"ChannelName\":\"InputData\",\"DataSource\":{\"S3DataSource\":{\"S3Uri\":\"https://s3.", + "\",\"AlgorithmSpecification\":{\"TrainingInputMode\":\"File\",\"AlgorithmName\":\"arn:aws:sagemaker:us-east-1:865070037744:algorithm/scikit-decision-trees-15423055-57b73412d2e93e9239e4e16f83298b8f\"},\"InputDataConfig\":[{\"ChannelName\":\"InputData\",\"DataSource\":{\"S3DataSource\":{\"S3Uri\":\"https://s3.", { "Ref": "AWS::Region" }, @@ -385,12 +391,6 @@ "/result/\"},\"ResourceConfig\":{\"InstanceCount\":1,\"InstanceType\":\"ml.m4.xlarge\",\"VolumeSizeInGB\":10},\"StoppingCondition\":{\"MaxRuntimeInSeconds\":3600}}}}}" ] ] - }, - "RoleArn": { - "Fn::GetAtt": [ - "StateMachineRoleB840431D", - "Arn" - ] } }, "DependsOn": [ @@ -398,5 +398,12 @@ "StateMachineRoleB840431D" ] } + }, + "Outputs": { + "stateMachineArn": { + "Value": { + 
"Ref": "StateMachine2E01A3A5" + } + } } } \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts index 74d89a5a5674c..28e4e65ff0e1e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/integ.create-training-job.ts @@ -1,10 +1,26 @@ import { Key } from '@aws-cdk/aws-kms'; import { Bucket, BucketEncryption } from '@aws-cdk/aws-s3'; import { StateMachine } from '@aws-cdk/aws-stepfunctions'; -import { App, RemovalPolicy, Stack } from '@aws-cdk/core'; +import { App, CfnOutput, RemovalPolicy, Stack } from '@aws-cdk/core'; import { S3Location } from '../../lib'; import { SageMakerCreateTrainingJob } from '../../lib/sagemaker/create-training-job'; +/* + * Creates a state machine with a task state to create a training job in AWS SageMaker + * SageMaker jobs need training algorithms. These can be found in the AWS marketplace + * or created. + * + * Subscribe to demo Algorithm vended by Amazon (free): + * https://aws.amazon.com/marketplace/ai/procurement?productId=cc5186a0-b8d6-4750-a9bb-1dcdf10e787a + * FIXME - create Input data pertinent for the training model and insert into S3 location specified in inputDataConfig. + * + * Stack verification steps: + * The generated State Machine can be executed from the CLI (or Step Functions console) + * and runs with an execution status of `Succeeded`. 
+ * + * -- aws stepfunctions start-execution --state-machine-arn provides execution arn + * -- aws stepfunctions describe-execution --execution-arn returns a status of `Succeeded` + */ const app = new App(); const stack = new Stack(app, 'integ-stepfunctions-sagemaker'); @@ -17,10 +33,10 @@ const trainingData = new Bucket(stack, 'TrainingData', { removalPolicy: RemovalPolicy.DESTROY, }); -new StateMachine(stack, 'StateMachine', { +const sm = new StateMachine(stack, 'StateMachine', { definition: new SageMakerCreateTrainingJob(stack, 'TrainTask', { algorithmSpecification: { - algorithmName: 'GRADIENT_ASCENT', + algorithmName: 'arn:aws:sagemaker:us-east-1:865070037744:algorithm/scikit-decision-trees-15423055-57b73412d2e93e9239e4e16f83298b8f', }, inputDataConfig: [{ channelName: 'InputData', dataSource: { s3DataSource: { @@ -28,6 +44,10 @@ new StateMachine(stack, 'StateMachine', { }, } }], outputDataConfig: { s3OutputLocation: S3Location.fromBucket(trainingData, 'result/') }, - trainingJobName: 'MyTrainingJob', + trainingJobName: 'mytrainingjob', }), }); + +new CfnOutput(stack, 'stateMachineArn', { + value: sm.stateMachineArn, +}); From b9c9923a460865f0e169f9f79d42732dfc018a46 Mon Sep 17 00:00:00 2001 From: Shiv Lakshminarayan Date: Fri, 5 Jun 2020 00:26:38 -0700 Subject: [PATCH 7/9] update README --- packages/@aws-cdk/aws-stepfunctions-tasks/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/README.md b/packages/@aws-cdk/aws-stepfunctions-tasks/README.md index 55e1f97989a64..e0e89b4ecd924 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/README.md +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/README.md @@ -639,7 +639,7 @@ new sfn.SagemakerTrainTask(this, 'TrainSagemaker', { resourceConfig: { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeSizeInGB: 50, + volumeSize: cdk.Size.gibibytes(50), }, stoppingCondition: { maxRuntime: 
cdk.Duration.hours(1), From 982b5cf941ed83079469719b87587d6d92c70c28 Mon Sep 17 00:00:00 2001 From: Shiv Lakshminarayan Date: Mon, 8 Jun 2020 20:25:44 -0700 Subject: [PATCH 8/9] PR feedback --- .../aws-stepfunctions-tasks/lib/sagemaker/base-types.ts | 2 +- .../lib/sagemaker/create-transform-job.ts | 2 +- .../test/sagemaker/create-transform-job.test.ts | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts index 7f19e5eeaf120..6f1c5f03dcc37 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/base-types.ts @@ -622,7 +622,7 @@ export interface TransformResources { * * @default - None */ - readonly volumeKmsKeyId?: kms.IKey; + readonly volumeEncryptionKey?: kms.IKey; } /** diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts index 816fb2ec45e52..111a15500443e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-transform-job.ts @@ -200,7 +200,7 @@ export class SageMakerCreateTransformJob extends sfn.TaskStateBase { TransformResources: { InstanceCount: resources.instanceCount, InstanceType: 'ml.' + resources.instanceType, - ...(resources.volumeKmsKeyId ? { VolumeKmsKeyId: resources.volumeKmsKeyId.keyArn } : {}), + ...(resources.volumeEncryptionKey ? 
{ VolumeKmsKeyId: resources.volumeEncryptionKey.keyArn } : {}), }, }; } diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts index 923118adccddb..c53233523cfa7 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker/create-transform-job.test.ts @@ -119,7 +119,7 @@ test('create complex transform job', () => { transformResources: { instanceCount: 1, instanceType: ec2.InstanceType.of(ec2.InstanceClass.P3, ec2.InstanceSize.XLARGE2), - volumeKmsKeyId: kmsKey, + volumeEncryptionKey: kmsKey, }, tags: { Project: 'MyProject', From 8c4f309a89d55e48d1d963854e41334e5cc6cc73 Mon Sep 17 00:00:00 2001 From: Shiv Lakshminarayan Date: Mon, 8 Jun 2020 21:10:53 -0700 Subject: [PATCH 9/9] remove Object.assign usage and use the spread operator instead --- .../lib/sagemaker/create-training-job.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts index f9c1309f59ce8..f541a0e692a4f 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker/create-training-job.ts @@ -157,7 +157,10 @@ export class SageMakerCreateTrainingJob extends sfn.TaskStateBase implements iam // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined this.inputDataConfig = props.inputDataConfig.map((config) => { if (!config.dataSource.s3DataSource.s3DataType) { - return Object.assign({}, config, { dataSource: { s3DataSource: { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } }); + return { + ...config, + dataSource: { s3DataSource: { 
...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } }, + }; } else { return config; }