From 67fd8501b20d0e00da31f6ecd1d37df441dddff8 Mon Sep 17 00:00:00 2001 From: Matt McClean Date: Wed, 5 Jun 2019 23:03:54 +0100 Subject: [PATCH 1/8] added inital SageMaker task --- .../aws-stepfunctions-tasks/lib/index.ts | 4 +- .../lib/sagemaker-base-types.ts | 243 ++++++++++++++++++ .../lib/sagemaker-tasks.ts | 70 +++++ .../aws-stepfunctions-tasks/package-lock.json | 47 ++++ .../aws-stepfunctions-tasks/package.json | 1 + 5 files changed, 364 insertions(+), 1 deletion(-) create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts index 0decc8f601c18..9601f8a35a51c 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts @@ -5,4 +5,6 @@ export * from './run-ecs-task-base-types'; export * from './publish-to-topic'; export * from './send-to-queue'; export * from './run-ecs-ec2-task'; -export * from './run-ecs-fargate-task'; \ No newline at end of file +export * from './run-ecs-fargate-task'; +export * from './sagemaker-base-types'; +export * from './sagemaker-tasks'; \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts new file mode 100644 index 0000000000000..79bcdc7811a52 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts @@ -0,0 +1,243 @@ +import ec2 = require('@aws-cdk/aws-ec2'); +import kms = require('@aws-cdk/aws-kms'); + +export interface AlgorithmSpecification { + + /** + * Name of the algorithm resource to use for the training job. + */ + readonly algorithmName?: string; + + /** + * List of metric definition objects. Each object specifies the metric name and regular expressions used to parse algorithm logs. + */ + readonly metricDefinitions?: MetricDefinition[]; + + /** + * Registry path of the Docker image that contains the training algorithm. + */ + readonly trainingImage?: string; + + /** + * Input mode that the algorithm supports. + */ + readonly trainingInputMode: InputMode.File | InputMode.Pipe; +} + +export interface Channel { + + /** + * Name of the channel + */ + readonly channelName: string; + + /** + * Compression type if training data is compressed + */ + readonly compressionType?: CompressionType.None | CompressionType.Gzip; + + /** + * Content type + */ + readonly contentType?: string; + + /** + * Location of the data channel + */ + readonly dataSource: DataSource; + + /** + * Input mode to use for the data channel in a training job. + */ + readonly inputMode?: InputMode.File | InputMode.Pipe; + + /** + * Record wrapper type + */ + readonly recordWrapperType?: RecordWrapperType.None | RecordWrapperType.RecordIO; + + /** + * Shuffle config option for input data in a channel. + */ + readonly shuffleConfig?: ShuffleConfig; +} + +export interface ShuffleConfig { + /** + * Determines the shuffling order. + */ + readonly seed: number; +} + +export interface DataSource { + /** + * S3 location of the data source that is associated with a channel. + */ + readonly s3DataSource: S3DataSource; +} + +export interface S3DataSource { + /** + * List of one or more attribute names to use that are found in a specified augmented manifest file. + */ + readonly attributeNames?: string[]; + + /** + * S3 Data Distribution Type + */ + readonly s3DataDistributionType?: S3DataDistributionType.FullyReplicated | S3DataDistributionType.ShardedByS3Key; + + /** + * S3 Data Type + */ + readonly s3DataType: S3DataType; + + /** + * S3 Uri + */ + readonly s3Uri: string; +} + +export interface OutputDataConfig { + /** + * Optional KMS encryption key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. + */ + readonly encryptionKey?: kms.IKey; + + /** + * Identifies the S3 path where you want Amazon SageMaker to store the model artifacts. + */ + readonly s3OutputPath: string; +} + +export interface StoppingCondition { + /** + * The maximum length of time, in seconds, that the training or compilation job can run. + */ + readonly maxRuntimeInSeconds?: number; +} + +export interface Tag { + /** + * Key tag. + */ + readonly key: string; + + /** + * Value tag. + */ + readonly value: string; +} + +export interface ResourceConfig { + + /** + * The number of ML compute instances to use. + */ + readonly instanceCount: number; + + /** + * ML compute instance type. + */ + readonly instanceType: ec2.InstanceType; + + /** + * KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job. + */ + readonly volumeKmsKeyId?: kms.IKey; + + /** + * Size of the ML storage volume that you want to provision. + */ + readonly volumeSizeInGB: number; +} + +export interface VpcConfig { + /** + * VPC security groups. + */ + readonly securityGroups: ec2.ISecurityGroup[]; + + /** + * VPC subnets. + */ + readonly subnetSelection: ec2.SubnetSelection; +} + +export interface MetricDefinition { + + /** + * Name of the metric. + */ + readonly name: string; + + /** + * Regular expression that searches the output of a training job and gets the value of the metric. + */ + readonly regex: string; +} + +export enum S3DataType { + /** + * Manifest File Data Type + */ + ManifestFile = 'ManifestFile', + + /** + * S3 Prefix Data Type + */ + S3Prefix = 'S3Prefix', + + /** + * Augmented Manifest File Data Type + */ + AugmentedManifestFile = 'AugmentedManifestFile' +} + +export enum S3DataDistributionType { + /** + * Fully replicated S3 Data Distribution Type + */ + FullyReplicated = 'FullyReplicated', + + /** + * Sharded By S3 Key Data Distribution Type + */ + ShardedByS3Key = 'ShardedByS3Key' +} + +export enum RecordWrapperType { + /** + * None record wrapper type + */ + None = 'None', + + /** + * RecordIO record wrapper type + */ + RecordIO = 'RecordIO' +} + +export enum InputMode { + /** + * Pipe mode + */ + Pipe = 'Pipe', + + /** + * File mode. + */ + File = 'File' +} + +export enum CompressionType { + /** + * None compression type + */ + None = 'None', + + /** + * Gzip compression type + */ + Gzip = 'Gzip' +} diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts new file mode 100644 index 0000000000000..f3bd093154872 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts @@ -0,0 +1,70 @@ +import iam = require('@aws-cdk/aws-iam'); +import sfn = require('@aws-cdk/aws-stepfunctions'); +import cdk = require('@aws-cdk/cdk'); +import { AlgorithmSpecification, Channel, OutputDataConfig, ResourceConfig, StoppingCondition, Tag, VpcConfig } from './sagemaker-base-types'; + +/** + * Basic properties for SageMaker CreateTrainingJob tasks. + */ +export interface CreateTrainingJobkProps { + + /** + * Name of the training job. + */ + readonly trainingJobName: string; + + /** + * Identifies the training algorithm to use. + */ + readonly algorithmSpec: AlgorithmSpecification; + + /** + * Enables encryption between ML compute instances. + */ + readonly enableInterContainerTrafficEncryption?: boolean; + + /** + * Isolates the training container. + */ + readonly enableNetworkIsolation?: boolean; + + /** + * Algorithm-specific parameters that influence the quality of the model. + */ + readonly hyperparameters?: {[key: string]: any}; + + /** + * Array of Channel objects. Each channel is a named input source. + */ + readonly inputDataConfig: Channel[]; + + /** + * Path to the S3 bucket where you want to store model artifacts. + */ + readonly outputDataConfig: OutputDataConfig; + + /** + * Resources to use for model training. + */ + readonly resourceConfig: ResourceConfig; + + /** + * IAM role that Amazon SageMaker can assume to perform tasks on your behalf. + */ + readonly role: iam.Role; + + /** + * Stopping condition + */ + readonly stoppingCondition?: StoppingCondition; + + /** + * Tags + */ + readonly tags?: Tag[]; + + /** + * VPC config + */ + readonly vpcConfig?: VpcConfig; +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/package-lock.json b/packages/@aws-cdk/aws-stepfunctions-tasks/package-lock.json index 1abfa69b77ed3..74d861656d1ee 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/package-lock.json +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/package-lock.json @@ -4,6 +4,53 @@ "lockfileVersion": 1, "requires": true, "dependencies": { + "@aws-cdk/aws-kms": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/aws-kms/-/aws-kms-0.33.0.tgz", + "integrity": "sha512-Yj1i/kqcpLu4LMIfIrk8F1Znereh7kL05+j7Ho0gy+HjwMbODIxB0BcyTiJ/5CjHkJPsR8Tl2GPyerZ6OGk/Dw==", + "requires": { + "@aws-cdk/aws-iam": "^0.33.0", + "@aws-cdk/cdk": "^0.33.0" + }, + "dependencies": { + "@aws-cdk/aws-iam": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/aws-iam/-/aws-iam-0.33.0.tgz", + "integrity": "sha512-d6HVkScJlG3a0rwWO0LgmZCTndze1c2cpoIezJINZ+sXPyMQiWWyFQDVTDC3LxPPUalG9t42gr2139d2zbfX6w==", + "requires": { + "@aws-cdk/cdk": "^0.33.0", + "@aws-cdk/region-info": "^0.33.0" + } + }, + "@aws-cdk/cdk": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/cdk/-/cdk-0.33.0.tgz", + "integrity": "sha512-ARfTC6ZTg1r2FOWntYo4kZ3S/Fju2vAagQavll56BJ3EPCxfYbPnIAWu3oFiSzg/4XQ345tbAZP1GSVZsF4RJw==", + "requires": { + "@aws-cdk/cx-api": "^0.33.0" + } + } + } + }, + "@aws-cdk/cx-api": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/cx-api/-/cx-api-0.33.0.tgz", + "integrity": "sha512-PvPO1quhrezUyYtyi3kEq4CHjmg5TccWQrU4khmTrP9bmb7sNKCmR7ish1VHcA2FBaNjtAj0PgdA+2/+Q+Pzrw==", + "requires": { + "semver": "^6.0.0" + }, + "dependencies": { + "semver": { + "version": "6.1.0", + "bundled": true + } + } + }, + "@aws-cdk/region-info": { + "version": "0.33.0", + "resolved": "https://registry.npmjs.org/@aws-cdk/region-info/-/region-info-0.33.0.tgz", + "integrity": "sha512-Sy0gXDqzGNuOYAF7edd5rlY3iChVSfjaaZ+bONyClF7gulkYv4jehYkQ1ShATl8XsVRedtCOwSU+mDo/tu8npA==" + }, "@babel/code-frame": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/@babel/code-frame/-/code-frame-7.0.0.tgz", diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/package.json b/packages/@aws-cdk/aws-stepfunctions-tasks/package.json index bae176346fa2d..b52d8ba331774 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/package.json +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/package.json @@ -83,6 +83,7 @@ "@aws-cdk/aws-ec2": "^0.33.0", "@aws-cdk/aws-ecs": "^0.33.0", "@aws-cdk/aws-iam": "^0.33.0", + "@aws-cdk/aws-kms": "^0.33.0", "@aws-cdk/aws-lambda": "^0.33.0", "@aws-cdk/aws-sns": "^0.33.0", "@aws-cdk/aws-sqs": "^0.33.0", From b6ce38b282d566542989f56419e5989067e6a2bc Mon Sep 17 00:00:00 2001 From: Matt McClean Date: Fri, 7 Jun 2019 22:50:49 +0100 Subject: [PATCH 2/8] added SageMaker Step Functions tasks --- .../aws-stepfunctions-tasks/lib/index.ts | 1 + .../lib/sagemaker-base-types.ts | 241 +++++++++++- .../lib/sagemaker-task-params.ts | 362 ++++++++++++++++++ .../lib/sagemaker-tasks.ts | 165 ++++---- .../aws-stepfunctions-tasks/package.json | 3 +- .../test/sagemaker-training-job.test.ts | 295 ++++++++++++++ .../test/sagemaker-transform-job.test.ts | 196 ++++++++++ 7 files changed, 1180 insertions(+), 83 deletions(-) create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-params.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts index 9601f8a35a51c..3fbac70910c90 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts @@ -7,4 +7,5 @@ export * from './send-to-queue'; export * from './run-ecs-ec2-task'; export * from './run-ecs-fargate-task'; export * from './sagemaker-base-types'; +export * from './sagemaker-task-params'; export * from './sagemaker-tasks'; \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts index 79bcdc7811a52..bb3ea9eabd17b 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts @@ -1,6 +1,13 @@ import ec2 = require('@aws-cdk/aws-ec2'); import kms = require('@aws-cdk/aws-kms'); +// +// Create Training Job types +// + +/** + * Identifies the training algorithm to use. + */ export interface AlgorithmSpecification { /** @@ -21,9 +28,12 @@ export interface AlgorithmSpecification { /** * Input mode that the algorithm supports. */ - readonly trainingInputMode: InputMode.File | InputMode.Pipe; + readonly trainingInputMode: InputMode; } +/** + * Describes the training, validation or test dataset and the Amazon S3 location where it is stored. + */ export interface Channel { /** @@ -34,7 +44,7 @@ export interface Channel { /** * Compression type if training data is compressed */ - readonly compressionType?: CompressionType.None | CompressionType.Gzip; + readonly compressionType?: CompressionType; /** * Content type @@ -49,12 +59,12 @@ export interface Channel { /** * Input mode to use for the data channel in a training job. */ - readonly inputMode?: InputMode.File | InputMode.Pipe; + readonly inputMode?: InputMode; /** * Record wrapper type */ - readonly recordWrapperType?: RecordWrapperType.None | RecordWrapperType.RecordIO; + readonly recordWrapperType?: RecordWrapperType; /** * Shuffle config option for input data in a channel. @@ -62,6 +72,9 @@ export interface Channel { readonly shuffleConfig?: ShuffleConfig; } +/** + * Configuration for a shuffle option for input data in a channel. + */ export interface ShuffleConfig { /** * Determines the shuffling order. @@ -69,6 +82,9 @@ export interface ShuffleConfig { readonly seed: number; } +/** + * Location of the channel data. + */ export interface DataSource { /** * S3 location of the data source that is associated with a channel. @@ -76,6 +92,9 @@ export interface DataSource { readonly s3DataSource: S3DataSource; } +/** + * S3 location of the channel data. + */ export interface S3DataSource { /** * List of one or more attribute names to use that are found in a specified augmented manifest file. @@ -85,7 +104,7 @@ export interface S3DataSource { /** * S3 Data Distribution Type */ - readonly s3DataDistributionType?: S3DataDistributionType.FullyReplicated | S3DataDistributionType.ShardedByS3Key; + readonly s3DataDistributionType?: S3DataDistributionType; /** * S3 Data Type @@ -98,6 +117,9 @@ export interface S3DataSource { readonly s3Uri: string; } +/** + * Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training. + */ export interface OutputDataConfig { /** * Optional KMS encryption key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. @@ -110,6 +132,9 @@ export interface OutputDataConfig { readonly s3OutputPath: string; } +/** + * Sets a time limit for training. + */ export interface StoppingCondition { /** * The maximum length of time, in seconds, that the training or compilation job can run. @@ -117,18 +142,9 @@ export interface StoppingCondition { readonly maxRuntimeInSeconds?: number; } -export interface Tag { - /** - * Key tag. - */ - readonly key: string; - - /** - * Value tag. - */ - readonly value: string; -} - +/** + * Identifies the resources, ML compute instances, and ML storage volumes to deploy for model training. + */ export interface ResourceConfig { /** @@ -152,18 +168,29 @@ export interface ResourceConfig { readonly volumeSizeInGB: number; } +/** + * Specifies the VPC that you want your training job to connect to. + */ export interface VpcConfig { /** * VPC security groups. */ readonly securityGroups: ec2.ISecurityGroup[]; + /** + * VPC id + */ + readonly vpc: ec2.Vpc; + /** * VPC subnets. */ - readonly subnetSelection: ec2.SubnetSelection; + readonly subnets: ec2.ISubnet[]; } +/** + * Specifies the metric name and regular expressions used to parse algorithm logs. + */ export interface MetricDefinition { /** @@ -177,6 +204,9 @@ export interface MetricDefinition { readonly regex: string; } +/** + * S3 Data Type. + */ export enum S3DataType { /** * Manifest File Data Type @@ -194,6 +224,9 @@ export enum S3DataType { AugmentedManifestFile = 'AugmentedManifestFile' } +/** + * S3 Data Distribution Type. + */ export enum S3DataDistributionType { /** * Fully replicated S3 Data Distribution Type @@ -206,6 +239,9 @@ export enum S3DataDistributionType { ShardedByS3Key = 'ShardedByS3Key' } +/** + * Define the format of the input data. + */ export enum RecordWrapperType { /** * None record wrapper type @@ -218,6 +254,9 @@ export enum RecordWrapperType { RecordIO = 'RecordIO' } +/** + * Input mode that the algorithm supports. + */ export enum InputMode { /** * Pipe mode @@ -230,6 +269,9 @@ export enum InputMode { File = 'File' } +/**Compression type of the data. + * + */ export enum CompressionType { /** * None compression type @@ -241,3 +283,166 @@ export enum CompressionType { */ Gzip = 'Gzip' } + +// +// Create Transform Job types +// + +/** + * Dataset to be transformed and the Amazon S3 location where it is stored. + */ +export interface TransformInput { + + /** + * The compression type of the transform data. + */ + readonly compressionType?: CompressionType; + + /** + * Multipurpose internet mail extension (MIME) type of the data. + */ + readonly contentType?: string; + + /** + * S3 location of the channel data + */ + readonly transformDataSource: TransformDataSource; + + /** + * + */ + readonly splitType?: SplitType; +} + +/** + * S3 location of the input data that the model can consume. + */ +export interface TransformDataSource { + + /** + * S3 location of the input data + */ + readonly s3DataSource: TransformS3DataSource; +} + +/** + * Location of the channel data. + */ +export interface TransformS3DataSource { + + /** + * S3 Data Type + */ + readonly s3DataType: S3DataType; + + /** + * Identifies either a key name prefix or a manifest. + */ + readonly s3Uri: string; +} + +/** + * S3 location where you want Amazon SageMaker to save the results from the transform job. + */ +export interface TransformOutput { + + /** + * MIME type used to specify the output data. + */ + readonly accept?: string; + + /** + * Defines how to assemble the results of the transform job as a single S3 object. + */ + readonly assembleWith?: AssembleWith; + + /** + * AWS KMS key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. + */ + readonly encryptionKey?: kms.Key; + + /** + * S3 path where you want Amazon SageMaker to store the results of the transform job. + */ + readonly s3OutputPath: string; +} + +/** + * ML compute instances for the transform job. + */ +export interface TransformResources { + + /** + * Nmber of ML compute instances to use in the transform job + */ + readonly instanceCount: number; + + /** + * ML compute instance type for the transform job. + */ + readonly instanceType: ec2.InstanceType; + + /** + * AWS KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s). + */ + readonly volumeKmsKeyId?: kms.Key; +} + +/** + * Specifies the number of records to include in a mini-batch for an HTTP inference request. + */ +export enum BatchStrategy { + + /** + * Fits multiple records in a mini-batch. + */ + MultiRecord = 'MultiRecord', + + /** + * Use a single record when making an invocation request. + */ + SingleRecord = 'SingleRecord' +} + +/** + * Method to use to split the transform job's data files into smaller batches. + */ +export enum SplitType { + + /** + * Input data files are not split, + */ + None = 'None', + + /** + * Split records on a newline character boundary. + */ + Line = 'Line', + + /** + * Split using MXNet RecordIO format. + */ + RecordIO = 'RecordIO', + + /** + * Split using TensorFlow TFRecord format. + */ + TFRecord = 'TFRecord' +} + +/** + * How to assemble the results of the transform job as a single S3 object. + */ +export enum AssembleWith { + + /** + * Concatenate the results in binary format. + */ + None = 'None', + + /** + * Add a newline character at the end of every transformed record. + */ + Line = 'Line' + +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-params.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-params.ts new file mode 100644 index 0000000000000..8d9058d1052b2 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-params.ts @@ -0,0 +1,362 @@ +import iam = require('@aws-cdk/aws-iam'); +import cdk = require('@aws-cdk/cdk'); +import { AlgorithmSpecification, BatchStrategy,Channel, OutputDataConfig, ResourceConfig, + TransformInput, TransformOutput, TransformResources, VpcConfig } from './sagemaker-base-types'; + +/** + * Parameters for the SageMaker Training Job. + */ +export class TrainingJobParameters extends cdk.Token { + + public trainingJobName: string; + public role: iam.Role; + private algorithmSpec: {[key: string]: any} = {}; + private hyperparameters: {[key: string]: any}; + private inputDataConfig = new Array(); + private outputDataConfig: {[key: string]: any} = {}; + private resourceConfig: {[key: string]: any} = {}; + private stoppingCondition: {[key: string]: any} = {}; + private tags = new Array(); + private vpcConfig: {[key: string]: any}; + + constructor(name: string, role: iam.Role) { + super(); + this.trainingJobName = name; + this.role = role; + } + + /** + * Adds the Alogorithm Specification config. + */ + public addAlgorithmSpec(spec: AlgorithmSpecification): TrainingJobParameters { + this.algorithmSpec = { + TrainingInputMode: spec.trainingInputMode, + ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage } : {}, + ...(spec.algorithmName) ? { AlgorithmName: spec.algorithmName } : {}, + }; + if (spec.metricDefinitions) { + this.algorithmSpec.MetricDefinitions = []; + spec.metricDefinitions.forEach(metric => this.algorithmSpec.MetricDefinitions.push( { Name: metric.name, Regex: metric.regex } )); + } + return this; + } + + /** + * Add a hyperparameter to the config to the parameters. + */ + public addHyperparameter(key: string, value: string): TrainingJobParameters { + if (! this.hyperparameters) { + this.hyperparameters = {}; + } + this.hyperparameters[key] = value; + return this; + } + + /** + * Add multiple hyperparameters to the config to the parameters. + */ + public addHyperparameters(params: {[key: string]: any}): TrainingJobParameters { + Object.keys(params).map(key => { + this.addHyperparameter(key, params[key]); + }); + return this; + } + + /** + * Add an InputDataConfig to the list of config objects to the parameters. + */ + public addInputDataConfig(config: Channel): TrainingJobParameters { + this.inputDataConfig.push({ + ChannelName: config.channelName, + DataSource: { + S3DataSource: { + S3Uri: config.dataSource.s3DataSource.s3Uri, + S3DataType: config.dataSource.s3DataSource.s3DataType, + ...(config.dataSource.s3DataSource.s3DataDistributionType) ? + { S3DataDistributionType: config.dataSource.s3DataSource.s3DataDistributionType} : {}, + ...(config.dataSource.s3DataSource.attributeNames) ? { AtttributeNames: config.dataSource.s3DataSource.attributeNames } : {}, + } + }, + ...(config.compressionType) ? { CompressionType: config.compressionType } : {}, + ...(config.contentType) ? { ContentType: config.contentType } : {}, + ...(config.inputMode) ? { InputMode: config.inputMode } : {}, + ...(config.recordWrapperType) ? { RecordWrapperType: config.recordWrapperType } : {}, + }); + return this; + } + + /** + * Add a list of InputDataConfig objects to the parameters. + */ + public addInputDataConfigs(configs: Channel[]): TrainingJobParameters { + configs.forEach(config => this.addInputDataConfig(config)); + return this; + } + + /** + * Add a single tag pair to the parameters. + */ + public addTag(name: string, value: string): TrainingJobParameters { + this.tags.push({ Key: name, Value: value }); + return this; + } + + /** + * Add multiple tag pairs to the parameters. + */ + public addTags(tags: {[key: string]: any}): TrainingJobParameters { + Object.keys(tags).map(key => { + this.addTag(key, tags[key]); + }); + return this; + } + + /** + * Adds the Output Data Config to the parameters. + */ + public addOutputDataConfig(config: OutputDataConfig): TrainingJobParameters { + this.outputDataConfig = { + S3OutputPath: config.s3OutputPath, + ...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {}, + }; + return this; + } + + /** + * Add a Resource Config to the parameters. + */ + public addResourceConfig(config: ResourceConfig): TrainingJobParameters { + this.resourceConfig = { + InstanceCount: config.instanceCount, + InstanceType: 'ml.' + config.instanceType, + VolumeSizeInGB: config.volumeSizeInGB, + ...(config.volumeKmsKeyId) ? { VolumeKmsKeyId: config.volumeKmsKeyId.keyArn } : {}, + }; + return this; + } + + /** + * Add a Stopping Condition to the parameters. + */ + public addStoppingCondition(maxRuntime: number): TrainingJobParameters { + this.stoppingCondition = { MaxRuntimeInSeconds: maxRuntime }; + return this; + } + + /** + * Add a VPC config to the parameters. + */ + public addVpcConfig(config: VpcConfig): TrainingJobParameters { + this.vpcConfig = { + SecurityGroupIds: [], + Subnets: [] + }; + config.securityGroups.forEach(sg => this.vpcConfig.SecurityGroupIds.push(sg.securityGroupId)); + config.subnets.forEach(subnet => this.vpcConfig.Subnets.push(subnet.subnetId)); + return this; + } + + // + // Serialization + // + public resolve(_context: cdk.IResolveContext): any { + return this.toJson(); + } + + public toJson(): any { + if (Object.entries(this.algorithmSpec).length === 0) { + throw new Error("Mandatory parameter 'AlgorithmSpecification' is empty"); + } + + if (Object.entries(this.inputDataConfig).length === 0) { + throw new Error("Mandatory parameter 'InputDataConfig' is empty"); + } + + if (Object.entries(this.outputDataConfig).length === 0) { + throw new Error("Mandatory parameter 'OutputDataConfig' is empty"); + } + + if (Object.entries(this.resourceConfig).length === 0) { + throw new Error("Mandatory parameter 'ResourceConfig' is empty"); + } + + if (Object.entries(this.stoppingCondition).length === 0) { + throw new Error("Mandatory parameter 'StoppingCondition' is empty"); + } + + return { + TrainingJobName: this.trainingJobName, + RoleArn: this.role.roleArn, + AlgorithmSpecification: this.algorithmSpec, + ...(this.hyperparameters) ? { HyperParameters: this.hyperparameters } : {}, + InputDataConfig: this.inputDataConfig, + OutputDataConfig: this.outputDataConfig, + ResourceConfig: this.resourceConfig, + StoppingCondition: this.stoppingCondition, + ...(this.tags.length > 0) ? { Tags: this.tags } : {}, + ...(this.vpcConfig) ? { VpcConfig: this.vpcConfig} : {}, + }; + } +} + +/** + * A class holding the SageMalker Transform Job parameters. + */ +export class TransformJobParameters extends cdk.Token { + + public transformJobName: string; + public role: iam.Role; + private batchStrategy: string; + private environmentVars: {[key: string]: any}; + private maxConcurrentTransforms: number; + private maxPayloadInMB: number; + private modelName: string; + private tags = new Array(); + private transformInput: {[key: string]: any} = {}; + private transformOutput: {[key: string]: any} = {}; + private transformResources: {[key: string]: any} = {}; + + constructor(jobName: string, modelName: string, role: iam.Role) { + super(); + this.transformJobName = jobName; + this.modelName = modelName; + this.role = role; + } + + /** + * Add a Batch strategy to the parameters. + */ + public addBatchStrategy(strategy: BatchStrategy): TransformJobParameters { + this.batchStrategy = strategy; + return this; + } + + /** + * Add an environment variable pair to the parameters. + */ + public addEnvironmentVar(key: string, value: string): TransformJobParameters { + if (! this.environmentVars) { + this.environmentVars = {}; + } + this.environmentVars[key] = value; + return this; + } + + /** + * Add multiple environment variable pairs to the parameters. + */ + public addEnvironmentVars(envars: {[key: string]: any}): TransformJobParameters { + Object.keys(envars).map(key => { + this.addEnvironmentVar(key, envars[key]); + }); + return this; + } + + /** + * Add a max concurrent transforms value to the parameters. + */ + public addMaxConcurrentTransforms(max: number): TransformJobParameters { + this.maxConcurrentTransforms = max; + return this; + } + + /** + * Add a max payload in MB to the parameters. + */ + public addMxaxPayloadInMB(max: number): TransformJobParameters { + this.maxPayloadInMB = max; + return this; + } + + /** + * Add an Transform Input config to the parameters. + */ + public addTransformInput(input: TransformInput): TransformJobParameters { + this.transformInput = { + DataSource: { + S3DataSource: { + S3Uri: input.transformDataSource.s3DataSource.s3Uri, + S3DataType: input.transformDataSource.s3DataSource.s3DataType, + } + }, + ...(input.compressionType) ? { CompressionType: input.compressionType } : {}, + ...(input.contentType) ? { ContentType: input.contentType } : {}, + ...(input.splitType) ? { SplitType: input.splitType } : {}, + }; + return this; + } + + /** + * Add a single tag pair to the parameters. + */ + public addTag(name: string, value: string): TransformJobParameters { + this.tags.push({ Key: name, Value: value }); + return this; + } + + /** + * Add multiple tag pairs + */ + public addTags(tags: {[key: string]: any}): TransformJobParameters { + Object.keys(tags).map(key => { + this.addTag(key, tags[key]); + }); + return this; + } + + /** + * Add a Transform Output config to the parameters. + */ + public addTransformOutput(output: TransformOutput): TransformJobParameters { + this.transformOutput = { + S3OutputPath: output.s3OutputPath, + ...(output.encryptionKey) ? { KmsKeyId: output.encryptionKey.keyArn } : {}, + ...(output.accept) ? { Accept: output.accept } : {}, + ...(output.assembleWith) ? { AssembleWith: output.assembleWith } : {}, + }; + return this; + } + + /** + * Add a Transform Resource config to the parameters. + */ + public addTransformResources(resource: TransformResources): TransformJobParameters { + this.transformResources = { + InstanceCount: resource.instanceCount, + InstanceType: 'ml.' + resource.instanceType, + ...(resource.volumeKmsKeyId) ? { VolumeKmsKeyId: resource.volumeKmsKeyId.keyArn } : {}, + }; + return this; + } + + public resolve(_context: cdk.IResolveContext): any { + return this.toJson(); + } + + public toJson(): any { + if (Object.entries(this.transformInput).length === 0) { + throw new Error("Mandatory parameter 'TransformInput' is empty"); + } + + if (Object.entries(this.transformOutput).length === 0) { + throw new Error("Mandatory parameter 'TransformOutput' is empty"); + } + + if (Object.entries(this.transformResources).length === 0) { + throw new Error("Mandatory parameter 'TransformResources' is empty"); + } + + return { + ...(this.batchStrategy) ? { BatchStrategy: this.batchStrategy} : {}, + ...(this.environmentVars) ? { Environment: this.environmentVars} : {}, + ...(this.maxConcurrentTransforms) ? { MaxConcurrentTransforms: this.maxConcurrentTransforms} : {}, + ...(this.maxPayloadInMB) ? { MaxPayloadInMB: this.maxPayloadInMB } : {}, + ModelName: this.modelName, + ...(this.tags.length > 0) ? { Tags: this.tags } : {}, + TransformInput: this.transformInput, + TransformJobName: this.transformJobName, + TransformOutput: this.transformOutput, + TransformResources: this.transformResources, + }; + } +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts index f3bd093154872..cb45e2681401e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts @@ -1,70 +1,107 @@ +import ec2 = require('@aws-cdk/aws-ec2'); import iam = require('@aws-cdk/aws-iam'); import sfn = require('@aws-cdk/aws-stepfunctions'); -import cdk = require('@aws-cdk/cdk'); -import { AlgorithmSpecification, Channel, OutputDataConfig, ResourceConfig, StoppingCondition, Tag, VpcConfig } from './sagemaker-base-types'; + +import { TrainingJobParameters, TransformJobParameters } from './sagemaker-task-params'; /** - * Basic properties for SageMaker CreateTrainingJob tasks. + * Class representing the SageMaker Create Training Job task. */ -export interface CreateTrainingJobkProps { - - /** - * Name of the training job. - */ - readonly trainingJobName: string; - - /** - * Identifies the training algorithm to use. - */ - readonly algorithmSpec: AlgorithmSpecification; - - /** - * Enables encryption between ML compute instances. - */ - readonly enableInterContainerTrafficEncryption?: boolean; - - /** - * Isolates the training container. - */ - readonly enableNetworkIsolation?: boolean; - - /** - * Algorithm-specific parameters that influence the quality of the model. - */ - readonly hyperparameters?: {[key: string]: any}; - - /** - * Array of Channel objects. Each channel is a named input source. - */ - readonly inputDataConfig: Channel[]; - - /** - * Path to the S3 bucket where you want to store model artifacts. - */ - readonly outputDataConfig: OutputDataConfig; - - /** - * Resources to use for model training. - */ - readonly resourceConfig: ResourceConfig; - - /** - * IAM role that Amazon SageMaker can assume to perform tasks on your behalf. - */ - readonly role: iam.Role; - - /** - * Stopping condition - */ - readonly stoppingCondition?: StoppingCondition; - - /** - * Tags - */ - readonly tags?: Tag[]; - - /** - * VPC config - */ - readonly vpcConfig?: VpcConfig; +export class SagemakerTrainingJobTask implements ec2.IConnectable, sfn.IStepFunctionsTask { + + public readonly connections: ec2.Connections = new ec2.Connections(); + + constructor(private readonly parameters: TrainingJobParameters, private readonly sync: boolean = false) {} + + public bind(task: sfn.Task): sfn.StepFunctionsTaskProperties { + return { + resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + (this.sync ? '.sync' : ''), + parameters: sfn.FieldUtils.renderObject(this.parameters.toJson()), + policyStatements: this.makePolicyStatements(task), + }; + } + + private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { + const stack = task.node.stack; + + // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html + const policyStatements = [ + new iam.PolicyStatement() + .addActions('sagemaker:CreateTrainingJob', 'sagemaker:DescribeTrainingJob', 'sagemaker:StopTrainingJob') + .addResource(stack.formatArn({ + service: 'sagemaker', + resource: 'training-job', + resourceName: '*' + })), + new iam.PolicyStatement() + .addAction('sagemaker:ListTags') + .addAllResources(), + new iam.PolicyStatement() + .addAction('iam:PassRole') + .addResources(this.parameters.role.roleArn) + .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) + ]; + + if (this.sync) { + policyStatements.push(new iam.PolicyStatement() + .addActions("events:PutTargets", "events:PutRule", "events:DescribeRule") + .addResource(stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule' + }))); + } + + return policyStatements; + } +} + +/** + * Class representing the SageMaker Create Transform Job task. + */ +export class SagemakerTransformJobTask implements sfn.IStepFunctionsTask { + + constructor(private readonly parameters: TransformJobParameters, private readonly sync: boolean = false) {} + + public bind(task: sfn.Task): sfn.StepFunctionsTaskProperties { + return { + resourceArn: 'arn:aws:states:::sagemaker:createTransformJob' + (this.sync ? '.sync' : ''), + parameters: sfn.FieldUtils.renderObject(this.parameters.toJson()), + policyStatements: this.makePolicyStatements(task), + }; + } + + private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { + const stack = task.node.stack; + + // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html + const policyStatements = [ + new iam.PolicyStatement() + .addActions('sagemaker:CreateTransformJob', 'sagemaker:DescribeTransformJob', 'sagemaker:StopTransformJob') + .addResource(stack.formatArn({ + service: 'sagemaker', + resource: 'transform-job', + resourceName: '*' + })), + new iam.PolicyStatement() + .addAction('sagemaker:ListTags') + .addAllResources(), + new iam.PolicyStatement() + .addAction('iam:PassRole') + .addResources(this.parameters.role.roleArn) + .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) + ]; + + if (this.sync) { + policyStatements.push(new iam.PolicyStatement() + .addActions("events:PutTargets", "events:PutRule", "events:DescribeRule") + .addResource(stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule' + }))); + } + + return policyStatements; + } } \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/package.json b/packages/@aws-cdk/aws-stepfunctions-tasks/package.json index b52d8ba331774..349b58a541f2b 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/package.json +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/package.json @@ -100,7 +100,8 @@ "@aws-cdk/aws-sns": "^0.33.0", "@aws-cdk/aws-sqs": "^0.33.0", "@aws-cdk/aws-stepfunctions": "^0.33.0", - "@aws-cdk/cdk": "^0.33.0" + "@aws-cdk/cdk": "^0.33.0", + "@aws-cdk/aws-kms": "^0.33.0" }, "engines": { "node": ">= 8.10.0" diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts new file mode 100644 index 0000000000000..628d8a40390b6 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts @@ -0,0 +1,295 @@ +import '@aws-cdk/assert/jest'; +import ec2 = require('@aws-cdk/aws-ec2'); +import iam = require('@aws-cdk/aws-iam'); +import kms = require('@aws-cdk/aws-kms'); +import sfn = require('@aws-cdk/aws-stepfunctions'); +import cdk = require('@aws-cdk/cdk'); +import tasks = require('../lib'); + +let stack: cdk.Stack; +let role: iam.Role; + +beforeEach(() => { + // GIVEN + stack = new cdk.Stack(); + role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', stack).policyArn + ], + }); + }); + +test('create basic training job', () => { + // WHEN + const params = new tasks.TrainingJobParameters("MyTrainJob", role); + params.addAlgorithmSpec({ algorithmName: "BlazingText", trainingInputMode: tasks.InputMode.File}) + .addInputDataConfig( + { + channelName: "train", + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: "s3://mybucket/mytrainpath" + } + } + }) + .addOutputDataConfig({ s3OutputPath: 's3://mybucket/myoutputpath' }) + .addResourceConfig( + { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeSizeInGB: 50 + }) + .addStoppingCondition(3600); + + const pub = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainingJobTask(params) }); + + // THEN + expect(stack.node.resolve(pub.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTrainingJob', + End: true, + Parameters: { + TrainingJobName: 'MyTrainJob', + RoleArn: { "Fn::GetAtt": [ "Role1ABCC5F0", "Arn" ] }, + AlgorithmSpecification: { + TrainingInputMode: 'File', + AlgorithmName: 'BlazingText', + }, + InputDataConfig: [ + { + ChannelName: 'train', + DataSource: { + S3DataSource: { + S3DataType: 'S3Prefix', + S3Uri: 's3://mybucket/mytrainpath' + } + } + } + ], + OutputDataConfig: { + S3OutputPath: 's3://mybucket/myoutputpath' + }, + ResourceConfig: { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeSizeInGB: 50 + }, + StoppingCondition: { + MaxRuntimeInSeconds: 3600 + } + }, + }); +}); + +test('create complex training job', () => { + // WHEN + const kmsKey = new kms.Key(stack, 'Key'); + const vpc = new ec2.Vpc(stack, "VPC"); + const securityGroup = new ec2.SecurityGroup(stack, 'SecurityGroup', { vpc, description: 'My SG' }); + securityGroup.addIngressRule(new ec2.AnyIPv4(), new ec2.TcpPort(22), 'allow ssh access from the world'); + + const params = new tasks.TrainingJobParameters("MyTrainJob", role); + params.addAlgorithmSpec( + { + algorithmName: "BlazingText", + trainingInputMode: tasks.InputMode.File, + metricDefinitions: [ + { + name: 'mymetric', regex: 'regex_pattern' + } + ] + }) + .addHyperparameter("lr", "0.1" ) + .addInputDataConfigs([ + { + channelName: "train", + contentType: "image/jpeg", + compressionType: tasks.CompressionType.None, + recordWrapperType: tasks.RecordWrapperType.RecordIO, + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: "s3://mybucket/mytrainpath", + } + } + }, + { + channelName: "test", + contentType: "image/jpeg", + compressionType: tasks.CompressionType.Gzip, + recordWrapperType: tasks.RecordWrapperType.RecordIO, + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: "s3://mybucket/mytestpath", + } + } + } + ]) + .addOutputDataConfig({ s3OutputPath: 's3://mybucket/myoutputpath', encryptionKey: kmsKey }) + .addResourceConfig( + { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeSizeInGB: 50, + volumeKmsKeyId: kmsKey, + }) + .addStoppingCondition(3600) + .addTag("Project", "MyProject") + .addVpcConfig({ vpc, subnets: vpc.privateSubnets, securityGroups: [ securityGroup ] }); + + const pub = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainingJobTask(params, true) }); + + // THEN + expect(stack.node.resolve(pub.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTrainingJob.sync', + End: true, + Parameters: { + TrainingJobName: 'MyTrainJob', + RoleArn: { "Fn::GetAtt": [ "Role1ABCC5F0", "Arn" ] }, + AlgorithmSpecification: { + TrainingInputMode: 'File', + AlgorithmName: 'BlazingText', + MetricDefinitions: [ + { Name: "mymetric", Regex: "regex_pattern" } + ] + }, + HyperParameters: { + lr: "0.1" + }, + InputDataConfig: [ + { + ChannelName: 'train', + CompressionType: 'None', + RecordWrapperType: 'RecordIO', + ContentType: 'image/jpeg', + DataSource: { + S3DataSource: { + S3DataType: 'S3Prefix', + S3Uri: 's3://mybucket/mytrainpath' + } + } + }, + { + ChannelName: 'test', + CompressionType: 'Gzip', + RecordWrapperType: 'RecordIO', + ContentType: 'image/jpeg', + DataSource: { + S3DataSource: { + S3DataType: 'S3Prefix', + S3Uri: 's3://mybucket/mytestpath' + } + } + } + ], + OutputDataConfig: { + S3OutputPath: 's3://mybucket/myoutputpath', + KmsKeyId: { "Fn::GetAtt": [ "Key961B73FD", "Arn" ] }, + }, + ResourceConfig: { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeSizeInGB: 50, + VolumeKmsKeyId: { "Fn::GetAtt": [ "Key961B73FD", "Arn" ] }, + }, + StoppingCondition: { + MaxRuntimeInSeconds: 3600 + }, + Tags: [ + { Key: "Project", Value: "MyProject" } + ], + VpcConfig: { + SecurityGroupIds: [ { "Fn::GetAtt": [ "SecurityGroupDD263621", "GroupId" ] } ], + Subnets: [ + { Ref: "VPCPrivateSubnet1Subnet8BCA10E0" }, + { Ref: "VPCPrivateSubnet2SubnetCFCDAA7A" }, + { Ref: "VPCPrivateSubnet3Subnet3EDCD457" } + ] + } + }, + }); +}); + +test('pass param to training job', () => { + // WHEN + const params = new tasks.TrainingJobParameters(sfn.Data.stringAt('$.JobName'), role); + params.addAlgorithmSpec({ algorithmName: "BlazingText", trainingInputMode: tasks.InputMode.File}) + .addInputDataConfig( + { + channelName: "train", + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: sfn.Data.stringAt('$.S3Bucket') + } + } + }) + .addOutputDataConfig({ s3OutputPath: 's3://mybucket/myoutputpath' }) + .addResourceConfig( + { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeSizeInGB: 50 + }) + .addStoppingCondition(3600); + + const pub = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainingJobTask(params) }); + + // THEN + expect(stack.node.resolve(pub.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTrainingJob', + End: true, + Parameters: { + 'TrainingJobName.$': '$.JobName', + 'RoleArn': { "Fn::GetAtt": [ "Role1ABCC5F0", "Arn" ] }, + 'AlgorithmSpecification': { + 'TrainingInputMode': 'File', + 'AlgorithmName': 'BlazingText', + }, + 'InputDataConfig': [ + { + 'ChannelName': 'train', + 'DataSource': { + 'S3DataSource': { + 'S3DataType': 'S3Prefix', + 'S3Uri.$': '$.S3Bucket' + } + } + } + ], + 'OutputDataConfig': { + 'S3OutputPath': 's3://mybucket/myoutputpath' + }, + 'ResourceConfig': { + 'InstanceCount': 1, + 'InstanceType': 'ml.p3.2xlarge', + 'VolumeSizeInGB': 50 + }, + 'StoppingCondition': { + 'MaxRuntimeInSeconds': 3600 + } + }, + }); +}); + +test('throw error when mandatory parameter not found', () => { + // WHEN + const params = new tasks.TrainingJobParameters(sfn.Data.stringAt('$.JobName'), role); + params.addAlgorithmSpec({ algorithmName: "BlazingText", trainingInputMode: tasks.InputMode.File}) + .addOutputDataConfig({ s3OutputPath: 's3://mybucket/myoutputpath' }) + .addResourceConfig( + { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeSizeInGB: 50 + }) + .addStoppingCondition(3600); + + // THEN + expect(() => params.toJson()).toThrow(); +}); \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts new file mode 100644 index 0000000000000..abca2727454d2 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts @@ -0,0 +1,196 @@ +import '@aws-cdk/assert/jest'; +import ec2 = require('@aws-cdk/aws-ec2'); +import iam = require('@aws-cdk/aws-iam'); +import kms = require('@aws-cdk/aws-kms'); +import sfn = require('@aws-cdk/aws-stepfunctions'); +import cdk = require('@aws-cdk/cdk'); +import tasks = require('../lib'); +import { S3DataType, BatchStrategy } from '../lib'; + +let stack: cdk.Stack; +let role: iam.Role; + +beforeEach(() => { + // GIVEN + stack = new cdk.Stack(); + role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', stack).policyArn + ], + }); + }); + +test('create basic transform job', () => { + // WHEN + const params = new tasks.TransformJobParameters("MyTransformJob", "MyModelName", role); + params.addTransformInput( + { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + s3DataType: S3DataType.S3Prefix, + } + } + }) + .addTransformOutput({ + s3OutputPath: 's3://outputbucket/prefix', + }) + .addTransformResources({ + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + }); + + const pub = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformJobTask(params) }); + + // THEN + expect(stack.node.resolve(pub.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTransformJob', + End: true, + Parameters: { + TransformJobName: 'MyTransformJob', + ModelName: 'MyModelName', + TransformInput: { + DataSource: { + S3DataSource: { + S3Uri: 's3://inputbucket/prefix', + S3DataType: 'S3Prefix', + } + } + }, + TransformOutput: { + S3OutputPath: 's3://outputbucket/prefix', + }, + TransformResources: { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + } + }, + }); +}); + +test('create complex transform job', () => { + // WHEN + const kmsKey = new kms.Key(stack, 'Key'); + const params = new tasks.TransformJobParameters("MyTransformJob", "MyModelName", role); + params.addTransformInput( + { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + s3DataType: S3DataType.S3Prefix, + } + } + }) + .addTransformOutput({ + s3OutputPath: 's3://outputbucket/prefix', + encryptionKey: kmsKey, + }) + .addTransformResources({ + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeKmsKeyId: kmsKey, + }) + .addTag('Project', 'MyProject') + .addBatchStrategy(BatchStrategy.MultiRecord) + .addEnvironmentVar('SOMEVAR', 'myvalue') + .addMaxConcurrentTransforms(3) + .addMxaxPayloadInMB(100); + + const pub = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformJobTask(params, true) }); + + // THEN + expect(stack.node.resolve(pub.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTransformJob.sync', + End: true, + Parameters: { + TransformJobName: 'MyTransformJob', + ModelName: 'MyModelName', + TransformInput: { + DataSource: { + S3DataSource: { + S3Uri: 's3://inputbucket/prefix', + S3DataType: 'S3Prefix', + } + } + }, + TransformOutput: { + S3OutputPath: 's3://outputbucket/prefix', + KmsKeyId: { "Fn::GetAtt": [ "Key961B73FD", "Arn" ] }, + }, + TransformResources: { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeKmsKeyId: { "Fn::GetAtt": [ "Key961B73FD", "Arn" ] }, + }, + Tags: [ + { Key: 'Project', Value: 'MyProject' } + ], + MaxConcurrentTransforms: 3, + MaxPayloadInMB: 100, + Environment: { + SOMEVAR: 'myvalue' + }, + BatchStrategy: 'MultiRecord' + }, + }); +}); + +test('pass param to transform job', () => { + // WHEN + const params = new tasks.TransformJobParameters(sfn.Data.stringAt('$.TransformJobName'), sfn.Data.stringAt('$.ModelName'), role); + params.addTransformInput( + { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + s3DataType: S3DataType.S3Prefix, + } + } + }) + .addTransformOutput({ + s3OutputPath: 's3://outputbucket/prefix', + }) + .addTransformResources({ + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + }); + + const pub = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformJobTask(params) }); + + // THEN + expect(stack.node.resolve(pub.toStateJson())).toEqual({ + Type: 'Task', + Resource: 'arn:aws:states:::sagemaker:createTransformJob', + End: true, + Parameters: { + 'TransformJobName.$': '$.TransformJobName', + 'ModelName.$': '$.ModelName', + 'TransformInput': { + DataSource: { + S3DataSource: { + S3Uri: 's3://inputbucket/prefix', + S3DataType: 'S3Prefix', + } + } + }, + 'TransformOutput': { + S3OutputPath: 's3://outputbucket/prefix', + }, + 'TransformResources': { + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + } + }, + }); +}); + +test('throw error when mandatory parameter not found', () => { + // WHEN + const params = new tasks.TransformJobParameters("MyTransformJob", "MyModelName", role); + + // THEN + expect(() => params.toJson()).toThrow(); +}); \ No newline at end of file From 9424f3dd44b30f0c97c11c36987427ce31fb33bf Mon Sep 17 00:00:00 2001 From: Matt McClean Date: Sat, 8 Jun 2019 11:45:17 +0100 Subject: [PATCH 3/8] refactored Sagemaker tasks --- .../aws-stepfunctions-tasks/lib/index.ts | 6 +- ...-types.ts => sagemaker-task-base-types.ts} | 15 - .../lib/sagemaker-task-params.ts | 362 ------------------ .../lib/sagemaker-tasks.ts | 107 ------ .../lib/sagemaker-train-task.ts | 208 ++++++++++ .../lib/sagemaker-transform-task.ts | 177 +++++++++ .../test/sagemaker-training-job.test.ts | 250 ++++++------ .../test/sagemaker-transform-job.test.ts | 132 +++---- 8 files changed, 585 insertions(+), 672 deletions(-) rename packages/@aws-cdk/aws-stepfunctions-tasks/lib/{sagemaker-base-types.ts => sagemaker-task-base-types.ts} (95%) delete mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-params.ts delete mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts create mode 100644 packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts index 3fbac70910c90..41f2533ba0149 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/index.ts @@ -6,6 +6,6 @@ export * from './publish-to-topic'; export * from './send-to-queue'; export * from './run-ecs-ec2-task'; export * from './run-ecs-fargate-task'; -export * from './sagemaker-base-types'; -export * from './sagemaker-task-params'; -export * from './sagemaker-tasks'; \ No newline at end of file +export * from './sagemaker-task-base-types'; +export * from './sagemaker-train-task'; +export * from './sagemaker-transform-task'; \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts similarity index 95% rename from packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts rename to packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts index bb3ea9eabd17b..236fd579f4902 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts @@ -5,9 +5,6 @@ import kms = require('@aws-cdk/aws-kms'); // Create Training Job types // -/** - * Identifies the training algorithm to use. - */ export interface AlgorithmSpecification { /** @@ -117,9 +114,6 @@ export interface S3DataSource { readonly s3Uri: string; } -/** - * Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training. - */ export interface OutputDataConfig { /** * Optional KMS encryption key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. @@ -132,9 +126,6 @@ export interface OutputDataConfig { readonly s3OutputPath: string; } -/** - * Sets a time limit for training. - */ export interface StoppingCondition { /** * The maximum length of time, in seconds, that the training or compilation job can run. @@ -142,9 +133,6 @@ export interface StoppingCondition { readonly maxRuntimeInSeconds?: number; } -/** - * Identifies the resources, ML compute instances, and ML storage volumes to deploy for model training. - */ export interface ResourceConfig { /** @@ -168,9 +156,6 @@ export interface ResourceConfig { readonly volumeSizeInGB: number; } -/** - * Specifies the VPC that you want your training job to connect to. - */ export interface VpcConfig { /** * VPC security groups. diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-params.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-params.ts deleted file mode 100644 index 8d9058d1052b2..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-params.ts +++ /dev/null @@ -1,362 +0,0 @@ -import iam = require('@aws-cdk/aws-iam'); -import cdk = require('@aws-cdk/cdk'); -import { AlgorithmSpecification, BatchStrategy,Channel, OutputDataConfig, ResourceConfig, - TransformInput, TransformOutput, TransformResources, VpcConfig } from './sagemaker-base-types'; - -/** - * Parameters for the SageMaker Training Job. - */ -export class TrainingJobParameters extends cdk.Token { - - public trainingJobName: string; - public role: iam.Role; - private algorithmSpec: {[key: string]: any} = {}; - private hyperparameters: {[key: string]: any}; - private inputDataConfig = new Array(); - private outputDataConfig: {[key: string]: any} = {}; - private resourceConfig: {[key: string]: any} = {}; - private stoppingCondition: {[key: string]: any} = {}; - private tags = new Array(); - private vpcConfig: {[key: string]: any}; - - constructor(name: string, role: iam.Role) { - super(); - this.trainingJobName = name; - this.role = role; - } - - /** - * Adds the Alogorithm Specification config. - */ - public addAlgorithmSpec(spec: AlgorithmSpecification): TrainingJobParameters { - this.algorithmSpec = { - TrainingInputMode: spec.trainingInputMode, - ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage } : {}, - ...(spec.algorithmName) ? { AlgorithmName: spec.algorithmName } : {}, - }; - if (spec.metricDefinitions) { - this.algorithmSpec.MetricDefinitions = []; - spec.metricDefinitions.forEach(metric => this.algorithmSpec.MetricDefinitions.push( { Name: metric.name, Regex: metric.regex } )); - } - return this; - } - - /** - * Add a hyperparameter to the config to the parameters. - */ - public addHyperparameter(key: string, value: string): TrainingJobParameters { - if (! this.hyperparameters) { - this.hyperparameters = {}; - } - this.hyperparameters[key] = value; - return this; - } - - /** - * Add multiple hyperparameters to the config to the parameters. - */ - public addHyperparameters(params: {[key: string]: any}): TrainingJobParameters { - Object.keys(params).map(key => { - this.addHyperparameter(key, params[key]); - }); - return this; - } - - /** - * Add an InputDataConfig to the list of config objects to the parameters. - */ - public addInputDataConfig(config: Channel): TrainingJobParameters { - this.inputDataConfig.push({ - ChannelName: config.channelName, - DataSource: { - S3DataSource: { - S3Uri: config.dataSource.s3DataSource.s3Uri, - S3DataType: config.dataSource.s3DataSource.s3DataType, - ...(config.dataSource.s3DataSource.s3DataDistributionType) ? - { S3DataDistributionType: config.dataSource.s3DataSource.s3DataDistributionType} : {}, - ...(config.dataSource.s3DataSource.attributeNames) ? { AtttributeNames: config.dataSource.s3DataSource.attributeNames } : {}, - } - }, - ...(config.compressionType) ? { CompressionType: config.compressionType } : {}, - ...(config.contentType) ? { ContentType: config.contentType } : {}, - ...(config.inputMode) ? { InputMode: config.inputMode } : {}, - ...(config.recordWrapperType) ? { RecordWrapperType: config.recordWrapperType } : {}, - }); - return this; - } - - /** - * Add a list of InputDataConfig objects to the parameters. - */ - public addInputDataConfigs(configs: Channel[]): TrainingJobParameters { - configs.forEach(config => this.addInputDataConfig(config)); - return this; - } - - /** - * Add a single tag pair to the parameters. - */ - public addTag(name: string, value: string): TrainingJobParameters { - this.tags.push({ Key: name, Value: value }); - return this; - } - - /** - * Add multiple tag pairs to the parameters. - */ - public addTags(tags: {[key: string]: any}): TrainingJobParameters { - Object.keys(tags).map(key => { - this.addTag(key, tags[key]); - }); - return this; - } - - /** - * Adds the Output Data Config to the parameters. - */ - public addOutputDataConfig(config: OutputDataConfig): TrainingJobParameters { - this.outputDataConfig = { - S3OutputPath: config.s3OutputPath, - ...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {}, - }; - return this; - } - - /** - * Add a Resource Config to the parameters. - */ - public addResourceConfig(config: ResourceConfig): TrainingJobParameters { - this.resourceConfig = { - InstanceCount: config.instanceCount, - InstanceType: 'ml.' + config.instanceType, - VolumeSizeInGB: config.volumeSizeInGB, - ...(config.volumeKmsKeyId) ? { VolumeKmsKeyId: config.volumeKmsKeyId.keyArn } : {}, - }; - return this; - } - - /** - * Add a Stopping Condition to the parameters. - */ - public addStoppingCondition(maxRuntime: number): TrainingJobParameters { - this.stoppingCondition = { MaxRuntimeInSeconds: maxRuntime }; - return this; - } - - /** - * Add a VPC config to the parameters. - */ - public addVpcConfig(config: VpcConfig): TrainingJobParameters { - this.vpcConfig = { - SecurityGroupIds: [], - Subnets: [] - }; - config.securityGroups.forEach(sg => this.vpcConfig.SecurityGroupIds.push(sg.securityGroupId)); - config.subnets.forEach(subnet => this.vpcConfig.Subnets.push(subnet.subnetId)); - return this; - } - - // - // Serialization - // - public resolve(_context: cdk.IResolveContext): any { - return this.toJson(); - } - - public toJson(): any { - if (Object.entries(this.algorithmSpec).length === 0) { - throw new Error("Mandatory parameter 'AlgorithmSpecification' is empty"); - } - - if (Object.entries(this.inputDataConfig).length === 0) { - throw new Error("Mandatory parameter 'InputDataConfig' is empty"); - } - - if (Object.entries(this.outputDataConfig).length === 0) { - throw new Error("Mandatory parameter 'OutputDataConfig' is empty"); - } - - if (Object.entries(this.resourceConfig).length === 0) { - throw new Error("Mandatory parameter 'ResourceConfig' is empty"); - } - - if (Object.entries(this.stoppingCondition).length === 0) { - throw new Error("Mandatory parameter 'StoppingCondition' is empty"); - } - - return { - TrainingJobName: this.trainingJobName, - RoleArn: this.role.roleArn, - AlgorithmSpecification: this.algorithmSpec, - ...(this.hyperparameters) ? { HyperParameters: this.hyperparameters } : {}, - InputDataConfig: this.inputDataConfig, - OutputDataConfig: this.outputDataConfig, - ResourceConfig: this.resourceConfig, - StoppingCondition: this.stoppingCondition, - ...(this.tags.length > 0) ? { Tags: this.tags } : {}, - ...(this.vpcConfig) ? { VpcConfig: this.vpcConfig} : {}, - }; - } -} - -/** - * A class holding the SageMalker Transform Job parameters. - */ -export class TransformJobParameters extends cdk.Token { - - public transformJobName: string; - public role: iam.Role; - private batchStrategy: string; - private environmentVars: {[key: string]: any}; - private maxConcurrentTransforms: number; - private maxPayloadInMB: number; - private modelName: string; - private tags = new Array(); - private transformInput: {[key: string]: any} = {}; - private transformOutput: {[key: string]: any} = {}; - private transformResources: {[key: string]: any} = {}; - - constructor(jobName: string, modelName: string, role: iam.Role) { - super(); - this.transformJobName = jobName; - this.modelName = modelName; - this.role = role; - } - - /** - * Add a Batch strategy to the parameters. - */ - public addBatchStrategy(strategy: BatchStrategy): TransformJobParameters { - this.batchStrategy = strategy; - return this; - } - - /** - * Add an environment variable pair to the parameters. - */ - public addEnvironmentVar(key: string, value: string): TransformJobParameters { - if (! this.environmentVars) { - this.environmentVars = {}; - } - this.environmentVars[key] = value; - return this; - } - - /** - * Add multiple environment variable pairs to the parameters. - */ - public addEnvironmentVars(envars: {[key: string]: any}): TransformJobParameters { - Object.keys(envars).map(key => { - this.addEnvironmentVar(key, envars[key]); - }); - return this; - } - - /** - * Add a max concurrent transforms value to the parameters. - */ - public addMaxConcurrentTransforms(max: number): TransformJobParameters { - this.maxConcurrentTransforms = max; - return this; - } - - /** - * Add a max payload in MB to the parameters. - */ - public addMxaxPayloadInMB(max: number): TransformJobParameters { - this.maxPayloadInMB = max; - return this; - } - - /** - * Add an Transform Input config to the parameters. - */ - public addTransformInput(input: TransformInput): TransformJobParameters { - this.transformInput = { - DataSource: { - S3DataSource: { - S3Uri: input.transformDataSource.s3DataSource.s3Uri, - S3DataType: input.transformDataSource.s3DataSource.s3DataType, - } - }, - ...(input.compressionType) ? { CompressionType: input.compressionType } : {}, - ...(input.contentType) ? { ContentType: input.contentType } : {}, - ...(input.splitType) ? { SplitType: input.splitType } : {}, - }; - return this; - } - - /** - * Add a single tag pair to the parameters. - */ - public addTag(name: string, value: string): TransformJobParameters { - this.tags.push({ Key: name, Value: value }); - return this; - } - - /** - * Add multiple tag pairs - */ - public addTags(tags: {[key: string]: any}): TransformJobParameters { - Object.keys(tags).map(key => { - this.addTag(key, tags[key]); - }); - return this; - } - - /** - * Add a Transform Output config to the parameters. - */ - public addTransformOutput(output: TransformOutput): TransformJobParameters { - this.transformOutput = { - S3OutputPath: output.s3OutputPath, - ...(output.encryptionKey) ? { KmsKeyId: output.encryptionKey.keyArn } : {}, - ...(output.accept) ? { Accept: output.accept } : {}, - ...(output.assembleWith) ? { AssembleWith: output.assembleWith } : {}, - }; - return this; - } - - /** - * Add a Transform Resource config to the parameters. - */ - public addTransformResources(resource: TransformResources): TransformJobParameters { - this.transformResources = { - InstanceCount: resource.instanceCount, - InstanceType: 'ml.' + resource.instanceType, - ...(resource.volumeKmsKeyId) ? { VolumeKmsKeyId: resource.volumeKmsKeyId.keyArn } : {}, - }; - return this; - } - - public resolve(_context: cdk.IResolveContext): any { - return this.toJson(); - } - - public toJson(): any { - if (Object.entries(this.transformInput).length === 0) { - throw new Error("Mandatory parameter 'TransformInput' is empty"); - } - - if (Object.entries(this.transformOutput).length === 0) { - throw new Error("Mandatory parameter 'TransformOutput' is empty"); - } - - if (Object.entries(this.transformResources).length === 0) { - throw new Error("Mandatory parameter 'TransformResources' is empty"); - } - - return { - ...(this.batchStrategy) ? { BatchStrategy: this.batchStrategy} : {}, - ...(this.environmentVars) ? { Environment: this.environmentVars} : {}, - ...(this.maxConcurrentTransforms) ? { MaxConcurrentTransforms: this.maxConcurrentTransforms} : {}, - ...(this.maxPayloadInMB) ? { MaxPayloadInMB: this.maxPayloadInMB } : {}, - ModelName: this.modelName, - ...(this.tags.length > 0) ? { Tags: this.tags } : {}, - TransformInput: this.transformInput, - TransformJobName: this.transformJobName, - TransformOutput: this.transformOutput, - TransformResources: this.transformResources, - }; - } -} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts deleted file mode 100644 index cb45e2681401e..0000000000000 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-tasks.ts +++ /dev/null @@ -1,107 +0,0 @@ -import ec2 = require('@aws-cdk/aws-ec2'); -import iam = require('@aws-cdk/aws-iam'); -import sfn = require('@aws-cdk/aws-stepfunctions'); - -import { TrainingJobParameters, TransformJobParameters } from './sagemaker-task-params'; - -/** - * Class representing the SageMaker Create Training Job task. - */ -export class SagemakerTrainingJobTask implements ec2.IConnectable, sfn.IStepFunctionsTask { - - public readonly connections: ec2.Connections = new ec2.Connections(); - - constructor(private readonly parameters: TrainingJobParameters, private readonly sync: boolean = false) {} - - public bind(task: sfn.Task): sfn.StepFunctionsTaskProperties { - return { - resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + (this.sync ? '.sync' : ''), - parameters: sfn.FieldUtils.renderObject(this.parameters.toJson()), - policyStatements: this.makePolicyStatements(task), - }; - } - - private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { - const stack = task.node.stack; - - // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html - const policyStatements = [ - new iam.PolicyStatement() - .addActions('sagemaker:CreateTrainingJob', 'sagemaker:DescribeTrainingJob', 'sagemaker:StopTrainingJob') - .addResource(stack.formatArn({ - service: 'sagemaker', - resource: 'training-job', - resourceName: '*' - })), - new iam.PolicyStatement() - .addAction('sagemaker:ListTags') - .addAllResources(), - new iam.PolicyStatement() - .addAction('iam:PassRole') - .addResources(this.parameters.role.roleArn) - .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) - ]; - - if (this.sync) { - policyStatements.push(new iam.PolicyStatement() - .addActions("events:PutTargets", "events:PutRule", "events:DescribeRule") - .addResource(stack.formatArn({ - service: 'events', - resource: 'rule', - resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule' - }))); - } - - return policyStatements; - } -} - -/** - * Class representing the SageMaker Create Transform Job task. - */ -export class SagemakerTransformJobTask implements sfn.IStepFunctionsTask { - - constructor(private readonly parameters: TransformJobParameters, private readonly sync: boolean = false) {} - - public bind(task: sfn.Task): sfn.StepFunctionsTaskProperties { - return { - resourceArn: 'arn:aws:states:::sagemaker:createTransformJob' + (this.sync ? '.sync' : ''), - parameters: sfn.FieldUtils.renderObject(this.parameters.toJson()), - policyStatements: this.makePolicyStatements(task), - }; - } - - private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { - const stack = task.node.stack; - - // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html - const policyStatements = [ - new iam.PolicyStatement() - .addActions('sagemaker:CreateTransformJob', 'sagemaker:DescribeTransformJob', 'sagemaker:StopTransformJob') - .addResource(stack.formatArn({ - service: 'sagemaker', - resource: 'transform-job', - resourceName: '*' - })), - new iam.PolicyStatement() - .addAction('sagemaker:ListTags') - .addAllResources(), - new iam.PolicyStatement() - .addAction('iam:PassRole') - .addResources(this.parameters.role.roleArn) - .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) - ]; - - if (this.sync) { - policyStatements.push(new iam.PolicyStatement() - .addActions("events:PutTargets", "events:PutRule", "events:DescribeRule") - .addResource(stack.formatArn({ - service: 'events', - resource: 'rule', - resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule' - }))); - } - - return policyStatements; - } -} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts new file mode 100644 index 0000000000000..f307330d47d21 --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts @@ -0,0 +1,208 @@ +import ec2 = require('@aws-cdk/aws-ec2'); +import iam = require('@aws-cdk/aws-iam'); +import sfn = require('@aws-cdk/aws-stepfunctions'); + +import { AlgorithmSpecification, Channel, OutputDataConfig, ResourceConfig, StoppingCondition, VpcConfig } from './sagemaker-task-base-types'; + +export interface SagemakerTrainProps { + + /** + * Training Job Name. + */ + readonly trainingJobName: string; + + /** + * Role for thte Training Job. + */ + readonly role: iam.Role; + + /** + * Specify if the task is synchronous or asychronous. + */ + readonly synchronous?: boolean; + + /** + * Identifies the training algorithm to use. + */ + readonly algorithmSpecification: AlgorithmSpecification; + + /** + * Hyperparameters to be used for the train job. + */ + readonly hyperparameters?: {[key: string]: any}; + + /** + * Describes the various datasets (e.g. train, validation, test) and the Amazon S3 location where stored. + */ + readonly inputDataConfig: Channel[]; + + /** + * Tags to be applied to the train job. + */ + readonly tags?: {[key: string]: any}; + + /** + * Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training. + */ + readonly outputDataConfig: OutputDataConfig; + + /** + * Identifies the resources, ML compute instances, and ML storage volumes to deploy for model training. + */ + readonly resourceConfig: ResourceConfig; + + /** + * Sets a time limit for training. + */ + readonly stoppingCondition: StoppingCondition; + + /** + * Specifies the VPC that you want your training job to connect to. + */ + readonly vpcConfig?: VpcConfig; +} + +/** + * Class representing the SageMaker Create Training Job task. + */ +export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsTask { + + public readonly connections: ec2.Connections = new ec2.Connections(); + + constructor(private readonly props: SagemakerTrainProps) { } + + public bind(task: sfn.Task): sfn.StepFunctionsTaskProperties { + return { + resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + (this.props.synchronous ? '.sync' : ''), + parameters: sfn.FieldUtils.renderObject(this.renderParameters()), + policyStatements: this.makePolicyStatements(task), + }; + } + + private renderParameters(): {[key: string]: any} { + return { + TrainingJobName: this.props.trainingJobName, + RoleArn: this.props.role.roleArn, + ...(this.renderAlgorithmSpecification(this.props.algorithmSpecification)), + ...(this.renderInputDataConfig(this.props.inputDataConfig)), + ...(this.renderOutputDataConfig(this.props.outputDataConfig)), + ...(this.renderResourceConfig(this.props.resourceConfig)), + ...(this.renderStoppingCondition(this.props.stoppingCondition)), + ...(this.renderHyperparameters(this.props.hyperparameters)), + ...(this.renderTags(this.props.tags)), + ...(this.renderVpcConfig(this.props.vpcConfig)), + }; + } + + private renderAlgorithmSpecification(spec: AlgorithmSpecification): {[key: string]: any} { + return { + AlgorithmSpecification: { + TrainingInputMode: spec.trainingInputMode, + ...(spec.trainingImage) ? { TrainingImage: spec.trainingImage } : {}, + ...(spec.algorithmName) ? { AlgorithmName: spec.algorithmName } : {}, + ...(spec.metricDefinitions) ? + { MetricDefinitions: spec.metricDefinitions + .map(metric => ({ Name: metric.name, Regex: metric.regex })) } : {} + } + }; + } + + private renderInputDataConfig(config: Channel[]): {[key: string]: any} { + return { + InputDataConfig: config.map(channel => ({ + ChannelName: channel.channelName, + DataSource: { + S3DataSource: { + S3Uri: channel.dataSource.s3DataSource.s3Uri, + S3DataType: channel.dataSource.s3DataSource.s3DataType, + ...(channel.dataSource.s3DataSource.s3DataDistributionType) ? + { S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType} : {}, + ...(channel.dataSource.s3DataSource.attributeNames) ? + { AtttributeNames: channel.dataSource.s3DataSource.attributeNames } : {}, + } + }, + ...(channel.compressionType) ? { CompressionType: channel.compressionType } : {}, + ...(channel.contentType) ? { ContentType: channel.contentType } : {}, + ...(channel.inputMode) ? { InputMode: channel.inputMode } : {}, + ...(channel.recordWrapperType) ? { RecordWrapperType: channel.recordWrapperType } : {}, + })) + }; + } + + private renderOutputDataConfig(config: OutputDataConfig): {[key: string]: any} { + return { + OutputDataConfig: { + S3OutputPath: config.s3OutputPath, + ...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {}, + } + }; + } + + private renderResourceConfig(config: ResourceConfig): {[key: string]: any} { + return { + ResourceConfig: { + InstanceCount: config.instanceCount, + InstanceType: 'ml.' + config.instanceType, + VolumeSizeInGB: config.volumeSizeInGB, + ...(config.volumeKmsKeyId) ? { VolumeKmsKeyId: config.volumeKmsKeyId.keyArn } : {}, + } + }; + } + + private renderStoppingCondition(config: StoppingCondition): {[key: string]: any} { + return { + StoppingCondition: { + MaxRuntimeInSeconds: config.maxRuntimeInSeconds + } + }; + } + + private renderHyperparameters(params: {[key: string]: any} | undefined): {[key: string]: any} { + return (params) ? { HyperParameters: params } : {}; + } + + private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { + return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; + } + + private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} { + return (config) ? { VpcConfig: { + SecurityGroupIds: config.securityGroups.map(sg => ( sg.securityGroupId )), + Subnets: config.subnets.map(subnet => ( subnet.subnetId )), + }} : {}; + } + + private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { + const stack = task.node.stack; + + // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html + const policyStatements = [ + new iam.PolicyStatement() + .addActions('sagemaker:CreateTrainingJob', 'sagemaker:DescribeTrainingJob', 'sagemaker:StopTrainingJob') + .addResource(stack.formatArn({ + service: 'sagemaker', + resource: 'training-job', + resourceName: '*' + })), + new iam.PolicyStatement() + .addAction('sagemaker:ListTags') + .addAllResources(), + new iam.PolicyStatement() + .addAction('iam:PassRole') + .addResources(this.props.role.roleArn) + .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) + ]; + + if (this.props.synchronous) { + policyStatements.push(new iam.PolicyStatement() + .addActions("events:PutTargets", "events:PutRule", "events:DescribeRule") + .addResource(stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTrainingJobsRule' + }))); + } + + return policyStatements; + } +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts new file mode 100644 index 0000000000000..3e5013395ef8c --- /dev/null +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts @@ -0,0 +1,177 @@ +import iam = require('@aws-cdk/aws-iam'); +import sfn = require('@aws-cdk/aws-stepfunctions'); + +import { BatchStrategy, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; + +export interface SagemakerTransformProps { + + /** + * Training Job Name. + */ + readonly transformJobName: string; + + /** + * Role for thte Training Job. + */ + readonly role: iam.Role; + + /** + * Specify if the task is synchronous or asychronous. + */ + readonly synchronous?: boolean; + + /** + * Number of records to include in a mini-batch for an HTTP inference request. + */ + readonly batchStrategy?: BatchStrategy; + + /** + * Environment variables to set in the Docker container. + */ + readonly environment?: {[key: string]: any}; + + /** + * Maximum number of parallel requests that can be sent to each instance in a transform job. + */ + readonly maxConcurrentTransforms?: number; + + /** + * Maximum allowed size of the payload, in MB. + */ + readonly maxPayloadInMB?: number; + + /** + * Name of the model that you want to use for the transform job. + */ + readonly modelName: string; + + /** + * Tags to be applied to the train job. + */ + readonly tags?: {[key: string]: any}; + + /** + * Dataset to be transformed and the Amazon S3 location where it is stored. + */ + readonly transformInput: TransformInput; + + /** + * S3 location where you want Amazon SageMaker to save the results from the transform job. + */ + readonly transformOutput: TransformOutput; + + /** + * ML compute instances for the transform job. + */ + readonly transformResources: TransformResources; +} + +/** + * Class representing the SageMaker Create Training Job task. + */ +export class SagemakerTransformTask implements sfn.IStepFunctionsTask { + + constructor(private readonly props: SagemakerTransformProps) { } + + public bind(task: sfn.Task): sfn.StepFunctionsTaskProperties { + return { + resourceArn: 'arn:aws:states:::sagemaker:createTransformJob' + (this.props.synchronous ? '.sync' : ''), + parameters: sfn.FieldUtils.renderObject(this.renderParameters()), + policyStatements: this.makePolicyStatements(task), + }; + } + + private renderParameters(): {[key: string]: any} { + return { + ...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {}, + ...(this.renderEnvironment(this.props.environment)), + ...(this.props.maxConcurrentTransforms) ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}, + ...(this.props.maxPayloadInMB) ? { MaxPayloadInMB: this.props.maxPayloadInMB } : {}, + ModelName: this.props.modelName, + ...(this.renderTags(this.props.tags)), + ...(this.renderTransformInput(this.props.transformInput)), + TransformJobName: this.props.transformJobName, + ...(this.renderTransformOutput(this.props.transformOutput)), + ...(this.renderTransformResources(this.props.transformResources)), + }; + } + + private renderTransformInput(input: TransformInput): {[key: string]: any} { + return { + TransformInput: { + ...(input.compressionType) ? { CompressionType: input.compressionType } : {}, + ...(input.contentType) ? { ContentType: input.contentType } : {}, + DataSource: { + S3DataSource: { + S3Uri: input.transformDataSource.s3DataSource.s3Uri, + S3DataType: input.transformDataSource.s3DataSource.s3DataType, + } + }, + ...(input.splitType) ? { SplitType: input.splitType } : {}, + } + }; + } + + private renderTransformOutput(output: TransformOutput): {[key: string]: any} { + return { + TransformOutput: { + S3OutputPath: output.s3OutputPath, + ...(output.encryptionKey) ? { KmsKeyId: output.encryptionKey.keyArn } : {}, + ...(output.accept) ? { Accept: output.accept } : {}, + ...(output.assembleWith) ? { AssembleWith: output.assembleWith } : {}, + } + }; + } + + private renderTransformResources(resources: TransformResources): {[key: string]: any} { + return { + TransformResources: { + InstanceCount: resources.instanceCount, + InstanceType: 'ml.' + resources.instanceType, + ...(resources.volumeKmsKeyId) ? { VolumeKmsKeyId: resources.volumeKmsKeyId.keyArn } : {}, + } + }; + } + + private renderEnvironment(environment: {[key: string]: any} | undefined): {[key: string]: any} { + return (environment) ? { Environment: environment } : {}; + } + + private renderTags(tags: {[key: string]: any} | undefined): {[key: string]: any} { + return (tags) ? { Tags: Object.keys(tags).map(key => ({ Key: key, Value: tags[key] })) } : {}; + } + + private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { + const stack = task.node.stack; + + // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html + const policyStatements = [ + new iam.PolicyStatement() + .addActions('sagemaker:CreateTransformJob', 'sagemaker:DescribeTransformJob', 'sagemaker:StopTransformJob') + .addResource(stack.formatArn({ + service: 'sagemaker', + resource: 'transform-job', + resourceName: '*' + })), + new iam.PolicyStatement() + .addAction('sagemaker:ListTags') + .addAllResources(), + new iam.PolicyStatement() + .addAction('iam:PassRole') + .addResources(this.props.role.roleArn) + .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) + ]; + + if (this.props.synchronous) { + policyStatements.push(new iam.PolicyStatement() + .addActions("events:PutTargets", "events:PutRule", "events:DescribeRule") + .addResource(stack.formatArn({ + service: 'events', + resource: 'rule', + resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule' + }))); + } + + return policyStatements; + } +} \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts index 628d8a40390b6..8e4f382c751bf 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts @@ -22,31 +22,39 @@ beforeEach(() => { test('create basic training job', () => { // WHEN - const params = new tasks.TrainingJobParameters("MyTrainJob", role); - params.addAlgorithmSpec({ algorithmName: "BlazingText", trainingInputMode: tasks.InputMode.File}) - .addInputDataConfig( - { - channelName: "train", - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3Prefix, - s3Uri: "s3://mybucket/mytrainpath" - } + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ + trainingJobName: "MyTrainJob", + role, + algorithmSpecification: { + algorithmName: "BlazingText", + trainingInputMode: tasks.InputMode.File + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: "s3://mybucket/mytrainpath" } - }) - .addOutputDataConfig({ s3OutputPath: 's3://mybucket/myoutputpath' }) - .addResourceConfig( - { - instanceCount: 1, - instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), - volumeSizeInGB: 50 - }) - .addStoppingCondition(3600); - - const pub = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainingJobTask(params) }); + } + } + ], + outputDataConfig: { + s3OutputPath: 's3://mybucket/myoutputpath' + }, + resourceConfig: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeSizeInGB: 50 + }, + stoppingCondition: { + maxRuntimeInSeconds: 3600 + } + })}); // THEN - expect(stack.node.resolve(pub.toStateJson())).toEqual({ + expect(stack.node.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTrainingJob', End: true, @@ -90,60 +98,73 @@ test('create complex training job', () => { const securityGroup = new ec2.SecurityGroup(stack, 'SecurityGroup', { vpc, description: 'My SG' }); securityGroup.addIngressRule(new ec2.AnyIPv4(), new ec2.TcpPort(22), 'allow ssh access from the world'); - const params = new tasks.TrainingJobParameters("MyTrainJob", role); - params.addAlgorithmSpec( - { - algorithmName: "BlazingText", - trainingInputMode: tasks.InputMode.File, - metricDefinitions: [ - { - name: 'mymetric', regex: 'regex_pattern' - } - ] - }) - .addHyperparameter("lr", "0.1" ) - .addInputDataConfigs([ + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ + trainingJobName: "MyTrainJob", + synchronous: true, + role, + algorithmSpecification: { + algorithmName: "BlazingText", + trainingInputMode: tasks.InputMode.File, + metricDefinitions: [ { - channelName: "train", - contentType: "image/jpeg", - compressionType: tasks.CompressionType.None, - recordWrapperType: tasks.RecordWrapperType.RecordIO, - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3Prefix, - s3Uri: "s3://mybucket/mytrainpath", - } + name: 'mymetric', regex: 'regex_pattern' + } + ] + }, + hyperparameters: { + lr: "0.1" + }, + inputDataConfig: [ + { + channelName: "train", + contentType: "image/jpeg", + compressionType: tasks.CompressionType.None, + recordWrapperType: tasks.RecordWrapperType.RecordIO, + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: "s3://mybucket/mytrainpath", } - }, - { - channelName: "test", - contentType: "image/jpeg", - compressionType: tasks.CompressionType.Gzip, - recordWrapperType: tasks.RecordWrapperType.RecordIO, - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3Prefix, - s3Uri: "s3://mybucket/mytestpath", - } + } + }, + { + channelName: "test", + contentType: "image/jpeg", + compressionType: tasks.CompressionType.Gzip, + recordWrapperType: tasks.RecordWrapperType.RecordIO, + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: "s3://mybucket/mytestpath", } } - ]) - .addOutputDataConfig({ s3OutputPath: 's3://mybucket/myoutputpath', encryptionKey: kmsKey }) - .addResourceConfig( - { - instanceCount: 1, - instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), - volumeSizeInGB: 50, - volumeKmsKeyId: kmsKey, - }) - .addStoppingCondition(3600) - .addTag("Project", "MyProject") - .addVpcConfig({ vpc, subnets: vpc.privateSubnets, securityGroups: [ securityGroup ] }); - - const pub = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainingJobTask(params, true) }); + } + ], + outputDataConfig: { + s3OutputPath: 's3://mybucket/myoutputpath', + encryptionKey: kmsKey + }, + resourceConfig: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeSizeInGB: 50, + volumeKmsKeyId: kmsKey, + }, + stoppingCondition: { + maxRuntimeInSeconds: 3600 + }, + tags: { + Project: "MyProject" + }, + vpcConfig: { + vpc, + subnets: vpc.privateSubnets, + securityGroups: [ securityGroup ] + } + })}); // THEN - expect(stack.node.resolve(pub.toStateJson())).toEqual({ + expect(stack.node.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTrainingJob.sync', End: true, @@ -216,31 +237,39 @@ test('create complex training job', () => { test('pass param to training job', () => { // WHEN - const params = new tasks.TrainingJobParameters(sfn.Data.stringAt('$.JobName'), role); - params.addAlgorithmSpec({ algorithmName: "BlazingText", trainingInputMode: tasks.InputMode.File}) - .addInputDataConfig( - { - channelName: "train", - dataSource: { - s3DataSource: { - s3DataType: tasks.S3DataType.S3Prefix, - s3Uri: sfn.Data.stringAt('$.S3Bucket') - } + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ + trainingJobName: sfn.Data.stringAt('$.JobName'), + role, + algorithmSpecification: { + algorithmName: "BlazingText", + trainingInputMode: tasks.InputMode.File + }, + inputDataConfig: [ + { + channelName: 'train', + dataSource: { + s3DataSource: { + s3DataType: tasks.S3DataType.S3Prefix, + s3Uri: sfn.Data.stringAt('$.S3Bucket') } - }) - .addOutputDataConfig({ s3OutputPath: 's3://mybucket/myoutputpath' }) - .addResourceConfig( - { - instanceCount: 1, - instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), - volumeSizeInGB: 50 - }) - .addStoppingCondition(3600); - - const pub = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainingJobTask(params) }); + } + } + ], + outputDataConfig: { + s3OutputPath: 's3://mybucket/myoutputpath' + }, + resourceConfig: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeSizeInGB: 50 + }, + stoppingCondition: { + maxRuntimeInSeconds: 3600 + } + })}); // THEN - expect(stack.node.resolve(pub.toStateJson())).toEqual({ + expect(stack.node.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTrainingJob', End: true, @@ -248,14 +277,14 @@ test('pass param to training job', () => { 'TrainingJobName.$': '$.JobName', 'RoleArn': { "Fn::GetAtt": [ "Role1ABCC5F0", "Arn" ] }, 'AlgorithmSpecification': { - 'TrainingInputMode': 'File', - 'AlgorithmName': 'BlazingText', + TrainingInputMode: 'File', + AlgorithmName: 'BlazingText', }, 'InputDataConfig': [ { - 'ChannelName': 'train', - 'DataSource': { - 'S3DataSource': { + ChannelName: 'train', + DataSource: { + S3DataSource: { 'S3DataType': 'S3Prefix', 'S3Uri.$': '$.S3Bucket' } @@ -263,33 +292,16 @@ test('pass param to training job', () => { } ], 'OutputDataConfig': { - 'S3OutputPath': 's3://mybucket/myoutputpath' + S3OutputPath: 's3://mybucket/myoutputpath' }, 'ResourceConfig': { - 'InstanceCount': 1, - 'InstanceType': 'ml.p3.2xlarge', - 'VolumeSizeInGB': 50 + InstanceCount: 1, + InstanceType: 'ml.p3.2xlarge', + VolumeSizeInGB: 50 }, 'StoppingCondition': { - 'MaxRuntimeInSeconds': 3600 + MaxRuntimeInSeconds: 3600 } }, }); -}); - -test('throw error when mandatory parameter not found', () => { - // WHEN - const params = new tasks.TrainingJobParameters(sfn.Data.stringAt('$.JobName'), role); - params.addAlgorithmSpec({ algorithmName: "BlazingText", trainingInputMode: tasks.InputMode.File}) - .addOutputDataConfig({ s3OutputPath: 's3://mybucket/myoutputpath' }) - .addResourceConfig( - { - instanceCount: 1, - instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), - volumeSizeInGB: 50 - }) - .addStoppingCondition(3600); - - // THEN - expect(() => params.toJson()).toThrow(); }); \ No newline at end of file diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts index abca2727454d2..1c74bcf810a1b 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts @@ -5,7 +5,7 @@ import kms = require('@aws-cdk/aws-kms'); import sfn = require('@aws-cdk/aws-stepfunctions'); import cdk = require('@aws-cdk/cdk'); import tasks = require('../lib'); -import { S3DataType, BatchStrategy } from '../lib'; +import { BatchStrategy, S3DataType } from '../lib'; let stack: cdk.Stack; let role: iam.Role; @@ -23,28 +23,29 @@ beforeEach(() => { test('create basic transform job', () => { // WHEN - const params = new tasks.TransformJobParameters("MyTransformJob", "MyModelName", role); - params.addTransformInput( - { - transformDataSource: { - s3DataSource: { - s3Uri: 's3://inputbucket/prefix', - s3DataType: S3DataType.S3Prefix, - } + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ + transformJobName: "MyTransformJob", + modelName: "MyModelName", + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + s3DataType: S3DataType.S3Prefix, } - }) - .addTransformOutput({ - s3OutputPath: 's3://outputbucket/prefix', - }) - .addTransformResources({ - instanceCount: 1, - instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), - }); - - const pub = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformJobTask(params) }); + } + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + }, + transformResources: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + } + }) }); // THEN - expect(stack.node.resolve(pub.toStateJson())).toEqual({ + expect(stack.node.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTransformJob', End: true, @@ -73,35 +74,41 @@ test('create basic transform job', () => { test('create complex transform job', () => { // WHEN const kmsKey = new kms.Key(stack, 'Key'); - const params = new tasks.TransformJobParameters("MyTransformJob", "MyModelName", role); - params.addTransformInput( - { - transformDataSource: { - s3DataSource: { - s3Uri: 's3://inputbucket/prefix', - s3DataType: S3DataType.S3Prefix, - } + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ + transformJobName: "MyTransformJob", + modelName: "MyModelName", + synchronous: true, + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/prefix', + s3DataType: S3DataType.S3Prefix, } - }) - .addTransformOutput({ - s3OutputPath: 's3://outputbucket/prefix', - encryptionKey: kmsKey, - }) - .addTransformResources({ - instanceCount: 1, - instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), - volumeKmsKeyId: kmsKey, - }) - .addTag('Project', 'MyProject') - .addBatchStrategy(BatchStrategy.MultiRecord) - .addEnvironmentVar('SOMEVAR', 'myvalue') - .addMaxConcurrentTransforms(3) - .addMxaxPayloadInMB(100); - - const pub = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformJobTask(params, true) }); + } + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + encryptionKey: kmsKey, + }, + transformResources: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), + volumeKmsKeyId: kmsKey, + }, + tags: { + Project: 'MyProject', + }, + batchStrategy: BatchStrategy.MultiRecord, + environment: { + SOMEVAR: 'myvalue' + }, + maxConcurrentTransforms: 3, + maxPayloadInMB: 100, + }) }); // THEN - expect(stack.node.resolve(pub.toStateJson())).toEqual({ + expect(stack.node.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTransformJob.sync', End: true, @@ -140,28 +147,29 @@ test('create complex transform job', () => { test('pass param to transform job', () => { // WHEN - const params = new tasks.TransformJobParameters(sfn.Data.stringAt('$.TransformJobName'), sfn.Data.stringAt('$.ModelName'), role); - params.addTransformInput( - { + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ + transformJobName: sfn.Data.stringAt('$.TransformJobName'), + modelName: sfn.Data.stringAt('$.ModelName'), + role, + transformInput: { transformDataSource: { s3DataSource: { s3Uri: 's3://inputbucket/prefix', s3DataType: S3DataType.S3Prefix, } } - }) - .addTransformOutput({ - s3OutputPath: 's3://outputbucket/prefix', - }) - .addTransformResources({ + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/prefix', + }, + transformResources: { instanceCount: 1, instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), - }); - - const pub = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformJobTask(params) }); + } + }) }); // THEN - expect(stack.node.resolve(pub.toStateJson())).toEqual({ + expect(stack.node.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTransformJob', End: true, @@ -185,12 +193,4 @@ test('pass param to transform job', () => { } }, }); -}); - -test('throw error when mandatory parameter not found', () => { - // WHEN - const params = new tasks.TransformJobParameters("MyTransformJob", "MyModelName", role); - - // THEN - expect(() => params.toJson()).toThrow(); }); \ No newline at end of file From a4e58e2e2113577d7f39de1dd6c0fc5994250816 Mon Sep 17 00:00:00 2001 From: Matt McClean Date: Sat, 8 Jun 2019 12:10:35 +0100 Subject: [PATCH 4/8] fixed breaking changes to stack and sfn objects --- .../aws-stepfunctions-tasks/lib/sagemaker-train-task.ts | 6 +++--- .../aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts | 6 +++--- .../test/sagemaker-training-job.test.ts | 6 +++--- .../test/sagemaker-transform-job.test.ts | 6 +++--- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts index f307330d47d21..1a68206fd7014 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts @@ -1,7 +1,7 @@ import ec2 = require('@aws-cdk/aws-ec2'); import iam = require('@aws-cdk/aws-iam'); import sfn = require('@aws-cdk/aws-stepfunctions'); - +import { Stack } from '@aws-cdk/cdk'; import { AlgorithmSpecification, Channel, OutputDataConfig, ResourceConfig, StoppingCondition, VpcConfig } from './sagemaker-task-base-types'; export interface SagemakerTrainProps { @@ -71,7 +71,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT constructor(private readonly props: SagemakerTrainProps) { } - public bind(task: sfn.Task): sfn.StepFunctionsTaskProperties { + public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { return { resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + (this.props.synchronous ? '.sync' : ''), parameters: sfn.FieldUtils.renderObject(this.renderParameters()), @@ -173,7 +173,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT } private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { - const stack = task.node.stack; + const stack = Stack.of(task); // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html const policyStatements = [ diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts index 3e5013395ef8c..9ff5848739ae9 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts @@ -1,6 +1,6 @@ import iam = require('@aws-cdk/aws-iam'); import sfn = require('@aws-cdk/aws-stepfunctions'); - +import { Stack } from '@aws-cdk/cdk'; import { BatchStrategy, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; export interface SagemakerTransformProps { @@ -73,7 +73,7 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { constructor(private readonly props: SagemakerTransformProps) { } - public bind(task: sfn.Task): sfn.StepFunctionsTaskProperties { + public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { return { resourceArn: 'arn:aws:states:::sagemaker:createTransformJob' + (this.props.synchronous ? '.sync' : ''), parameters: sfn.FieldUtils.renderObject(this.renderParameters()), @@ -142,7 +142,7 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { } private makePolicyStatements(task: sfn.Task): iam.PolicyStatement[] { - const stack = task.node.stack; + const stack = Stack.of(task); // https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html const policyStatements = [ diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts index 8e4f382c751bf..a795a3c15f473 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts @@ -54,7 +54,7 @@ test('create basic training job', () => { })}); // THEN - expect(stack.node.resolve(task.toStateJson())).toEqual({ + expect(stack.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTrainingJob', End: true, @@ -164,7 +164,7 @@ test('create complex training job', () => { })}); // THEN - expect(stack.node.resolve(task.toStateJson())).toEqual({ + expect(stack.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTrainingJob.sync', End: true, @@ -269,7 +269,7 @@ test('pass param to training job', () => { })}); // THEN - expect(stack.node.resolve(task.toStateJson())).toEqual({ + expect(stack.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTrainingJob', End: true, diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts index 1c74bcf810a1b..95dc160ff3633 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts @@ -45,7 +45,7 @@ test('create basic transform job', () => { }) }); // THEN - expect(stack.node.resolve(task.toStateJson())).toEqual({ + expect(stack.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTransformJob', End: true, @@ -108,7 +108,7 @@ test('create complex transform job', () => { }) }); // THEN - expect(stack.node.resolve(task.toStateJson())).toEqual({ + expect(stack.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTransformJob.sync', End: true, @@ -169,7 +169,7 @@ test('pass param to transform job', () => { }) }); // THEN - expect(stack.node.resolve(task.toStateJson())).toEqual({ + expect(stack.resolve(task.toStateJson())).toEqual({ Type: 'Task', Resource: 'arn:aws:states:::sagemaker:createTransformJob', End: true, From 5fffde6d30b3e84ce303f70f1553e4cb40227ec2 Mon Sep 17 00:00:00 2001 From: Matt McClean Date: Mon, 10 Jun 2019 13:23:27 +0100 Subject: [PATCH 5/8] updated step functions readme --- packages/@aws-cdk/aws-stepfunctions/README.md | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/packages/@aws-cdk/aws-stepfunctions/README.md b/packages/@aws-cdk/aws-stepfunctions/README.md index 0bb4278c64659..6cb6f99c8720b 100644 --- a/packages/@aws-cdk/aws-stepfunctions/README.md +++ b/packages/@aws-cdk/aws-stepfunctions/README.md @@ -126,6 +126,8 @@ couple of the tasks available are: * `tasks.SendToQueue` -- send a message to an SQS queue * `tasks.RunEcsFargateTask`/`ecs.RunEcsEc2Task` -- run a container task, depending on the type of capacity. +* `tasks.SagemakerTrainTask` -- run a SageMaker training job +* `tasks.SagemakerTransformTask` -- run a SageMaker transform job #### Task parameters from the state json @@ -249,6 +251,34 @@ const task = new sfn.Task(this, 'CallFargate', { }); ``` +#### SgaeMaker Transform example + +```ts +const transformJob = new tasks.SagemakerTransformTask( + transformJobName: "MyTransformJob", + modelName: "MyModelName", + role, + transformInput: { + transformDataSource: { + s3DataSource: { + s3Uri: 's3://inputbucket/train', + s3DataType: S3DataType.S3Prefix, + } + } + }, + transformOutput: { + s3OutputPath: 's3://outputbucket/TransformJobOutputPath', + }, + transformResources: { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), +}); + +const task = new sfn.Task(this, 'Batch Inference', { + task: transformJob +}); +``` + ### Pass A `Pass` state does no work, but it can optionally transform the execution's From 7e26c30e2fe468dfad65ad3a74640dbc3082ac4e Mon Sep 17 00:00:00 2001 From: Matt McClean Date: Mon, 10 Jun 2019 14:22:44 +0100 Subject: [PATCH 6/8] fixed lint problems --- .../lib/sagemaker-task-base-types.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts index 236fd579f4902..00a9249a7ee0d 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts @@ -254,8 +254,8 @@ export enum InputMode { File = 'File' } -/**Compression type of the data. - * +/** + * Compression type of the data. */ export enum CompressionType { /** @@ -294,7 +294,7 @@ export interface TransformInput { readonly transformDataSource: TransformDataSource; /** - * + * Method to use to split the transform job's data files into smaller batches. */ readonly splitType?: SplitType; } @@ -411,7 +411,7 @@ export enum SplitType { /** * Split using TensorFlow TFRecord format. - */ + */ TFRecord = 'TFRecord' } From 7057eeef3a0f885cc08e7a97c813c7e1806e33d4 Mon Sep 17 00:00:00 2001 From: Matt McClean Date: Mon, 10 Jun 2019 15:06:17 +0100 Subject: [PATCH 7/8] fixed lint errors --- .../aws-stepfunctions-tasks/lib/sagemaker-train-task.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts index 1a68206fd7014..fbd4667a22b8e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts @@ -52,7 +52,7 @@ export interface SagemakerTrainProps { readonly resourceConfig: ResourceConfig; /** - * Sets a time limit for training. + * Sets a time limit for training. */ readonly stoppingCondition: StoppingCondition; @@ -166,7 +166,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT } private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} { - return (config) ? { VpcConfig: { + return (config) ? { VpcConfig: { SecurityGroupIds: config.securityGroups.map(sg => ( sg.securityGroupId )), Subnets: config.subnets.map(subnet => ( subnet.subnetId )), }} : {}; From fb0bcbfb9c3b127a9d3229025a5d64d500c4de9e Mon Sep 17 00:00:00 2001 From: Matt McClean Date: Wed, 12 Jun 2019 11:19:10 +0100 Subject: [PATCH 8/8] updated with new default params and experimental tag --- .../lib/sagemaker-task-base-types.ts | 46 +++++++- .../lib/sagemaker-train-task.ts | 101 +++++++++++++++--- .../lib/sagemaker-transform-task.ts | 65 +++++++++-- .../test/sagemaker-training-job.test.ts | 50 ++++----- .../test/sagemaker-transform-job.test.ts | 14 +-- 5 files changed, 216 insertions(+), 60 deletions(-) diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts index 00a9249a7ee0d..0a8e6bf365795 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts @@ -5,6 +5,9 @@ import kms = require('@aws-cdk/aws-kms'); // Create Training Job types // +/** + * @experimental + */ export interface AlgorithmSpecification { /** @@ -24,12 +27,16 @@ export interface AlgorithmSpecification { /** * Input mode that the algorithm supports. + * + * @default is 'File' mode */ - readonly trainingInputMode: InputMode; + readonly trainingInputMode?: InputMode; } /** * Describes the training, validation or test dataset and the Amazon S3 location where it is stored. + * + * @experimental */ export interface Channel { @@ -71,6 +78,8 @@ export interface Channel { /** * Configuration for a shuffle option for input data in a channel. + * + * @experimental */ export interface ShuffleConfig { /** @@ -81,6 +90,8 @@ export interface ShuffleConfig { /** * Location of the channel data. + * + * @experimental */ export interface DataSource { /** @@ -91,6 +102,8 @@ export interface DataSource { /** * S3 location of the channel data. + * + * @experimental */ export interface S3DataSource { /** @@ -106,7 +119,7 @@ export interface S3DataSource { /** * S3 Data Type */ - readonly s3DataType: S3DataType; + readonly s3DataType?: S3DataType; /** * S3 Uri @@ -114,6 +127,9 @@ export interface S3DataSource { readonly s3Uri: string; } +/** + * @experimental + */ export interface OutputDataConfig { /** * Optional KMS encryption key that Amazon SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. @@ -137,11 +153,15 @@ export interface ResourceConfig { /** * The number of ML compute instances to use. + * + * @default 1 instance. */ readonly instanceCount: number; /** * ML compute instance type. + * + * @default is the 'm4.xlarge' instance type. */ readonly instanceType: ec2.InstanceType; @@ -152,10 +172,16 @@ export interface ResourceConfig { /** * Size of the ML storage volume that you want to provision. + * + * @default 10 GB EBS volume. */ readonly volumeSizeInGB: number; } +/** + * + * @experimental + */ export interface VpcConfig { /** * VPC security groups. @@ -175,6 +201,8 @@ export interface VpcConfig { /** * Specifies the metric name and regular expressions used to parse algorithm logs. + * + * @experimental */ export interface MetricDefinition { @@ -275,6 +303,8 @@ export enum CompressionType { /** * Dataset to be transformed and the Amazon S3 location where it is stored. + * + * @experimental */ export interface TransformInput { @@ -301,6 +331,8 @@ export interface TransformInput { /** * S3 location of the input data that the model can consume. + * + * @experimental */ export interface TransformDataSource { @@ -312,13 +344,17 @@ export interface TransformDataSource { /** * Location of the channel data. + * + * @experimental */ export interface TransformS3DataSource { /** * S3 Data Type + * + * @default 'S3Prefix' */ - readonly s3DataType: S3DataType; + readonly s3DataType?: S3DataType; /** * Identifies either a key name prefix or a manifest. @@ -328,6 +364,8 @@ export interface TransformS3DataSource { /** * S3 location where you want Amazon SageMaker to save the results from the transform job. + * + * @experimental */ export interface TransformOutput { @@ -354,6 +392,8 @@ export interface TransformOutput { /** * ML compute instances for the transform job. + * + * @experimental */ export interface TransformResources { diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts index fbd4667a22b8e..9d173384cdf0e 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts @@ -1,9 +1,13 @@ import ec2 = require('@aws-cdk/aws-ec2'); import iam = require('@aws-cdk/aws-iam'); import sfn = require('@aws-cdk/aws-stepfunctions'); -import { Stack } from '@aws-cdk/cdk'; -import { AlgorithmSpecification, Channel, OutputDataConfig, ResourceConfig, StoppingCondition, VpcConfig } from './sagemaker-task-base-types'; +import { Construct, Stack } from '@aws-cdk/cdk'; +import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig, + S3DataType, StoppingCondition, VpcConfig, } from './sagemaker-task-base-types'; +/** + * @experimental + */ export interface SagemakerTrainProps { /** @@ -14,10 +18,12 @@ export interface SagemakerTrainProps { /** * Role for thte Training Job. */ - readonly role: iam.Role; + readonly role?: iam.IRole; /** * Specify if the task is synchronous or asychronous. + * + * @default false */ readonly synchronous?: boolean; @@ -49,12 +55,12 @@ export interface SagemakerTrainProps { /** * Identifies the resources, ML compute instances, and ML storage volumes to deploy for model training. */ - readonly resourceConfig: ResourceConfig; + readonly resourceConfig?: ResourceConfig; /** * Sets a time limit for training. */ - readonly stoppingCondition: StoppingCondition; + readonly stoppingCondition?: StoppingCondition; /** * Specifies the VPC that you want your training job to connect to. @@ -67,9 +73,80 @@ export interface SagemakerTrainProps { */ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsTask { + /** + * Allows specify security group connections for instances of this fleet. + */ public readonly connections: ec2.Connections = new ec2.Connections(); - constructor(private readonly props: SagemakerTrainProps) { } + /** + * The execution role for the Sagemaker training job. + * + * @default new role for Amazon SageMaker to assume is automatically created. + */ + public readonly role: iam.IRole; + + /** + * The Algorithm Specification + */ + private readonly algorithmSpecification: AlgorithmSpecification; + + /** + * The Input Data Config. + */ + private readonly inputDataConfig: Channel[]; + + /** + * The resource config for the task. + */ + private readonly resourceConfig: ResourceConfig; + + /** + * The resource config for the task. + */ + private readonly stoppingCondition: StoppingCondition; + + constructor(scope: Construct, private readonly props: SagemakerTrainProps) { + + // set the default resource config if not defined. + this.resourceConfig = props.resourceConfig || { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), + volumeSizeInGB: 10 + }; + + // set the stopping condition if not defined + this.stoppingCondition = props.stoppingCondition || { + maxRuntimeInSeconds: 3600 + }; + + // set the sagemaker role or create new one + this.role = props.role || new iam.Role(scope, 'SagemakerRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', scope).policyArn + ] + }); + + // set the input mode to 'File' if not defined + this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ? + ( props.algorithmSpecification ) : + ( { ...props.algorithmSpecification, trainingInputMode: InputMode.File } ); + + // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined + this.inputDataConfig = props.inputDataConfig.map(config => { + if (!config.dataSource.s3DataSource.s3DataType) { + return Object.assign({}, config, { dataSource: { s3DataSource: + { ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3Prefix } } }); + } else { + return config; + } + }); + + // add the security groups to the connections object + if (this.props.vpcConfig) { + this.props.vpcConfig.securityGroups.forEach(sg => this.connections.addSecurityGroup(sg)); + } + } public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { return { @@ -82,12 +159,12 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT private renderParameters(): {[key: string]: any} { return { TrainingJobName: this.props.trainingJobName, - RoleArn: this.props.role.roleArn, - ...(this.renderAlgorithmSpecification(this.props.algorithmSpecification)), - ...(this.renderInputDataConfig(this.props.inputDataConfig)), + RoleArn: this.role.roleArn, + ...(this.renderAlgorithmSpecification(this.algorithmSpecification)), + ...(this.renderInputDataConfig(this.inputDataConfig)), ...(this.renderOutputDataConfig(this.props.outputDataConfig)), - ...(this.renderResourceConfig(this.props.resourceConfig)), - ...(this.renderStoppingCondition(this.props.stoppingCondition)), + ...(this.renderResourceConfig(this.resourceConfig)), + ...(this.renderStoppingCondition(this.stoppingCondition)), ...(this.renderHyperparameters(this.props.hyperparameters)), ...(this.renderTags(this.props.tags)), ...(this.renderVpcConfig(this.props.vpcConfig)), @@ -189,7 +266,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT .addAllResources(), new iam.PolicyStatement() .addAction('iam:PassRole') - .addResources(this.props.role.roleArn) + .addResources(this.role.roleArn) .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) ]; diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts index 9ff5848739ae9..bfd365f88a293 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts @@ -1,8 +1,12 @@ +import ec2 = require('@aws-cdk/aws-ec2'); import iam = require('@aws-cdk/aws-iam'); import sfn = require('@aws-cdk/aws-stepfunctions'); -import { Stack } from '@aws-cdk/cdk'; -import { BatchStrategy, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; +import { Construct, Stack } from '@aws-cdk/cdk'; +import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types'; +/** + * @experimental + */ export interface SagemakerTransformProps { /** @@ -13,7 +17,7 @@ export interface SagemakerTransformProps { /** * Role for thte Training Job. */ - readonly role: iam.Role; + readonly role?: iam.IRole; /** * Specify if the task is synchronous or asychronous. @@ -63,15 +67,60 @@ export interface SagemakerTransformProps { /** * ML compute instances for the transform job. */ - readonly transformResources: TransformResources; + readonly transformResources?: TransformResources; } /** * Class representing the SageMaker Create Training Job task. + * + * @experimental */ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { - constructor(private readonly props: SagemakerTransformProps) { } + /** + * The execution role for the Sagemaker training job. + * + * @default new role for Amazon SageMaker to assume is automatically created. + */ + public readonly role: iam.IRole; + + /** + * Dataset to be transformed and the Amazon S3 location where it is stored. + */ + private readonly transformInput: TransformInput; + + /** + * ML compute instances for the transform job. + */ + private readonly transformResources: TransformResources; + + constructor(scope: Construct, private readonly props: SagemakerTransformProps) { + + // set the sagemaker role or create new one + this.role = props.role || new iam.Role(scope, 'SagemakerRole', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', scope).policyArn + ] + }); + + // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined + this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? (props.transformInput) : + Object.assign({}, props.transformInput, + { transformDataSource: + { s3DataSource: + { ...props.transformInput.transformDataSource.s3DataSource, + s3DataType: S3DataType.S3Prefix + } + } + }); + + // set the default value for the transform resources + this.transformResources = props.transformResources || { + instanceCount: 1, + instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.M4, ec2.InstanceSize.XLarge), + }; + } public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig { return { @@ -89,10 +138,10 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { ...(this.props.maxPayloadInMB) ? { MaxPayloadInMB: this.props.maxPayloadInMB } : {}, ModelName: this.props.modelName, ...(this.renderTags(this.props.tags)), - ...(this.renderTransformInput(this.props.transformInput)), + ...(this.renderTransformInput(this.transformInput)), TransformJobName: this.props.transformJobName, ...(this.renderTransformOutput(this.props.transformOutput)), - ...(this.renderTransformResources(this.props.transformResources)), + ...(this.renderTransformResources(this.transformResources)), }; } @@ -158,7 +207,7 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask { .addAllResources(), new iam.PolicyStatement() .addAction('iam:PassRole') - .addResources(this.props.role.roleArn) + .addResources(this.role.roleArn) .addCondition('StringEquals', { "iam:PassedToService": "sagemaker.amazonaws.com" }) ]; diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts index a795a3c15f473..dd8de65a04552 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-training-job.test.ts @@ -7,34 +7,24 @@ import cdk = require('@aws-cdk/cdk'); import tasks = require('../lib'); let stack: cdk.Stack; -let role: iam.Role; beforeEach(() => { // GIVEN stack = new cdk.Stack(); - role = new iam.Role(stack, 'Role', { - assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), - managedPolicyArns: [ - new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', stack).policyArn - ], - }); }); test('create basic training job', () => { // WHEN - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { trainingJobName: "MyTrainJob", - role, algorithmSpecification: { algorithmName: "BlazingText", - trainingInputMode: tasks.InputMode.File }, inputDataConfig: [ { channelName: 'train', dataSource: { s3DataSource: { - s3DataType: tasks.S3DataType.S3Prefix, s3Uri: "s3://mybucket/mytrainpath" } } @@ -43,14 +33,6 @@ test('create basic training job', () => { outputDataConfig: { s3OutputPath: 's3://mybucket/myoutputpath' }, - resourceConfig: { - instanceCount: 1, - instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), - volumeSizeInGB: 50 - }, - stoppingCondition: { - maxRuntimeInSeconds: 3600 - } })}); // THEN @@ -59,11 +41,9 @@ test('create basic training job', () => { Resource: 'arn:aws:states:::sagemaker:createTrainingJob', End: true, Parameters: { - TrainingJobName: 'MyTrainJob', - RoleArn: { "Fn::GetAtt": [ "Role1ABCC5F0", "Arn" ] }, AlgorithmSpecification: { - TrainingInputMode: 'File', AlgorithmName: 'BlazingText', + TrainingInputMode: 'File', }, InputDataConfig: [ { @@ -81,12 +61,14 @@ test('create basic training job', () => { }, ResourceConfig: { InstanceCount: 1, - InstanceType: 'ml.p3.2xlarge', - VolumeSizeInGB: 50 + InstanceType: 'ml.m4.xlarge', + VolumeSizeInGB: 10 }, + RoleArn: { "Fn::GetAtt": [ "SagemakerRole5FDB64E1", "Arn" ] }, StoppingCondition: { MaxRuntimeInSeconds: 3600 - } + }, + TrainingJobName: 'MyTrainJob', }, }); }); @@ -98,7 +80,14 @@ test('create complex training job', () => { const securityGroup = new ec2.SecurityGroup(stack, 'SecurityGroup', { vpc, description: 'My SG' }); securityGroup.addIngressRule(new ec2.AnyIPv4(), new ec2.TcpPort(22), 'allow ssh access from the world'); - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ + const role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', stack).policyArn + ], + }); + + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { trainingJobName: "MyTrainJob", synchronous: true, role, @@ -237,7 +226,14 @@ test('create complex training job', () => { test('pass param to training job', () => { // WHEN - const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask({ + const role = new iam.Role(stack, 'Role', { + assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'), + managedPolicyArns: [ + new iam.AwsManagedPolicy('AmazonSageMakerFullAccess', stack).policyArn + ], + }); + + const task = new sfn.Task(stack, 'TrainSagemaker', { task: new tasks.SagemakerTrainTask(stack, { trainingJobName: sfn.Data.stringAt('$.JobName'), role, algorithmSpecification: { diff --git a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts index 95dc160ff3633..e6a25b0f490dc 100644 --- a/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions-tasks/test/sagemaker-transform-job.test.ts @@ -23,25 +23,19 @@ beforeEach(() => { test('create basic transform job', () => { // WHEN - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { transformJobName: "MyTransformJob", modelName: "MyModelName", - role, transformInput: { transformDataSource: { s3DataSource: { s3Uri: 's3://inputbucket/prefix', - s3DataType: S3DataType.S3Prefix, } } }, transformOutput: { s3OutputPath: 's3://outputbucket/prefix', }, - transformResources: { - instanceCount: 1, - instanceType: new ec2.InstanceTypePair(ec2.InstanceClass.P3, ec2.InstanceSize.XLarge2), - } }) }); // THEN @@ -65,7 +59,7 @@ test('create basic transform job', () => { }, TransformResources: { InstanceCount: 1, - InstanceType: 'ml.p3.2xlarge', + InstanceType: 'ml.m4.xlarge', } }, }); @@ -74,7 +68,7 @@ test('create basic transform job', () => { test('create complex transform job', () => { // WHEN const kmsKey = new kms.Key(stack, 'Key'); - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { transformJobName: "MyTransformJob", modelName: "MyModelName", synchronous: true, @@ -147,7 +141,7 @@ test('create complex transform job', () => { test('pass param to transform job', () => { // WHEN - const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask({ + const task = new sfn.Task(stack, 'TransformTask', { task: new tasks.SagemakerTransformTask(stack, { transformJobName: sfn.Data.stringAt('$.TransformJobName'), modelName: sfn.Data.stringAt('$.ModelName'), role,