Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(stepfunctions): Downscope SageMaker permissions #2991

Merged
merged 12 commits into from
Jul 3, 2019
Merged
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
import ec2 = require('@aws-cdk/aws-ec2');
import ecr = require('@aws-cdk/aws-ecr');
import { DockerImageAsset, DockerImageAssetProps } from '@aws-cdk/aws-ecr-assets';
import iam = require('@aws-cdk/aws-iam');
import kms = require('@aws-cdk/aws-kms');
import { Duration } from '@aws-cdk/core';
import s3 = require('@aws-cdk/aws-s3');
import sfn = require('@aws-cdk/aws-stepfunctions');
import { Construct, Duration } from '@aws-cdk/core';

export interface ISageMakerTask extends sfn.IStepFunctionsTask, iam.IGrantable {}

//
// Create Training Job types
Expand All @@ -24,7 +31,7 @@ export interface AlgorithmSpecification {
/**
* Registry path of the Docker image that contains the training algorithm.
*/
readonly trainingImage?: string;
readonly trainingImage?: DockerImage;

/**
* Input mode that the algorithm supports.
Expand Down Expand Up @@ -125,7 +132,7 @@ export interface S3DataSource {
/**
* S3 Uri
*/
readonly s3Uri: string;
readonly s3Location: S3Location;
}

/**
Expand All @@ -140,16 +147,22 @@ export interface OutputDataConfig {
/**
* Identifies the S3 path where you want Amazon SageMaker to store the model artifacts.
*/
readonly s3OutputPath: string;
readonly s3OutputLocation: S3Location;
}

/**
* @experimental
*/
export interface StoppingCondition {
/**
* The maximum length of time, in seconds, that the training or compilation job can run.
*/
readonly maxRuntime?: Duration;
}

/**
* @experimental
*/
export interface ResourceConfig {

/**
Expand All @@ -169,7 +182,7 @@ export interface ResourceConfig {
/**
* KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job.
*/
readonly volumeKmsKeyId?: kms.IKey;
readonly volumeEncryptionKey?: kms.IKey;

/**
* Size of the ML storage volume that you want to provision.
Expand Down Expand Up @@ -218,8 +231,139 @@ export interface MetricDefinition {
readonly regex: string;
}

/**
* @experimental
*/
export interface S3LocationConfig {
readonly uri: string;
}

/**
* Constructs `IS3Location` objects.
*
* @experimental
*/
export abstract class S3Location {
/**
* An `IS3Location` built with a determined bucket and key prefix.
*
* @param bucket is the bucket where the objects are to be stored.
* @param keyPrefix is the key prefix used by the location.
*/
public static fromBucket(bucket: s3.IBucket, keyPrefix: string): S3Location {
return new StandardS3Location({ bucket, keyPrefix, uri: bucket.urlForObject(keyPrefix) });
}

/**
* An `IS3Location` determined fully by a JSON Path from the task input.
*
* Due to the dynamic nature of those locations, the IAM grants that will be set by `grantRead` and `grantWrite`
* apply to the `*` resource.
*
* @param expression the JSON expression resolving to an S3 location URI.
*/
public static fromJsonExpression(expression: string): S3Location {
return new StandardS3Location({ uri: sfn.Data.stringAt(expression) });
}

/**
* Called when the S3Location is bound to a StepFunctions task.
*/
public abstract bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig;
}

/**
* Options for binding an S3 Location.
*
* @experimental
*/
export interface S3LocationBindOptions {
/**
* Allow reading from the S3 Location.
*
* @default false
*/
readonly forReading?: boolean;

/**
* Allow writing to the S3 Location.
*
* @default false
*/
readonly forWriting?: boolean;
}

/**
* Configuration for a using Docker image.
*
* @experimental
*/
export interface DockerImageConfig {
/**
* The fully qualified URI of the Docker image.
*/
readonly imageUri: string;
}

/**
* Creates `IDockerImage` instances.
*
* @experimental
*/
export abstract class DockerImage {
/**
* Reference a Docker image stored in an ECR repository.
*
* @param repository the ECR repository where the image is hosted.
* @param tag an optional `tag`
*/
public static fromEcrRepository(repository: ecr.IRepository, tag: string = 'latest'): DockerImage {
return new StandardDockerImage({ repository, imageUri: repository.repositoryUriForTag(tag) });
}

/**
* Reference a Docker image which URI is obtained from the task's input.
*
* @param expression the JSON path expression with the task input.
* @param allowAnyEcrImagePull whether ECR access should be permitted (set to `false` if the image will never be in ECR).
*/
public static fromJsonExpression(expression: string, allowAnyEcrImagePull = true): DockerImage {
return new StandardDockerImage({ imageUri: expression, allowAnyEcrImagePull });
}

/**
* Reference a Docker image by it's URI.
*
* When referencing ECR images, prefer using `inEcr`.
*
* @param imageUri the URI to the docker image.
*/
public static fromRegistry(imageUri: string): DockerImage {
return new StandardDockerImage({ imageUri });
}

/**
* Reference a Docker image that is provided as an Asset in the current app.
*
* @param scope the scope in which to create the Asset.
* @param id the ID for the asset in the construct tree.
* @param props the configuration props of the asset.
*/
public static fromAsset(scope: Construct, id: string, props: DockerImageAssetProps): DockerImage {
const asset = new DockerImageAsset(scope, id, props);
return new StandardDockerImage({ repository: asset.repository, imageUri: asset.imageUri });
}

/**
* Called when the image is used by a SageMaker task.
*/
public abstract bind(task: ISageMakerTask): DockerImageConfig;
}

/**
* S3 Data Type.
*
* @experimental
*/
export enum S3DataType {
/**
Expand All @@ -240,6 +384,8 @@ export enum S3DataType {

/**
* S3 Data Distribution Type.
*
* @experimental
*/
export enum S3DataDistributionType {
/**
Expand All @@ -255,6 +401,8 @@ export enum S3DataDistributionType {

/**
* Define the format of the input data.
*
* @experimental
*/
export enum RecordWrapperType {
/**
Expand All @@ -270,6 +418,8 @@ export enum RecordWrapperType {

/**
* Input mode that the algorithm supports.
*
* @experimental
*/
export enum InputMode {
/**
Expand All @@ -285,6 +435,8 @@ export enum InputMode {

/**
* Compression type of the data.
*
* @experimental
*/
export enum CompressionType {
/**
Expand Down Expand Up @@ -416,6 +568,8 @@ export interface TransformResources {

/**
* Specifies the number of records to include in a mini-batch for an HTTP inference request.
*
* @experimental
*/
export enum BatchStrategy {

Expand All @@ -432,6 +586,8 @@ export enum BatchStrategy {

/**
* Method to use to split the transform job's data files into smaller batches.
*
* @experimental
*/
export enum SplitType {

Expand All @@ -458,6 +614,8 @@ export enum SplitType {

/**
* How to assemble the results of the transform job as a single S3 object.
*
* @experimental
*/
export enum AssembleWith {

Expand All @@ -472,3 +630,70 @@ export enum AssembleWith {
LINE = 'Line'

}

class StandardDockerImage extends DockerImage {
private readonly allowAnyEcrImagePull: boolean;
private readonly imageUri: string;
private readonly repository?: ecr.IRepository;

constructor(opts: { allowAnyEcrImagePull?: boolean, imageUri: string, repository?: ecr.IRepository }) {
super();

this.allowAnyEcrImagePull = !!opts.allowAnyEcrImagePull;
this.imageUri = opts.imageUri;
this.repository = opts.repository;
}

public bind(task: ISageMakerTask): DockerImageConfig {
if (this.repository) {
this.repository.grantPull(task);
}
if (this.allowAnyEcrImagePull) {
task.grantPrincipal.addToPolicy(new iam.PolicyStatement({
actions: [
'ecr:BatchCheckLayerAvailability',
'ecr:GetDownloadUrlForLayer',
'ecr:BatchGetImage',
],
resources: ['*']
}));
}
return {
imageUri: this.imageUri,
};
}
}

class StandardS3Location extends S3Location {
private readonly bucket?: s3.IBucket;
private readonly keyGlob: string;
private readonly uri: string;

constructor(opts: { bucket?: s3.IBucket, keyPrefix?: string, uri: string }) {
super();
this.bucket = opts.bucket;
this.keyGlob = `${opts.keyPrefix || ''}*`;
this.uri = opts.uri;
}

public bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig {
if (this.bucket) {
if (opts.forReading) {
this.bucket.grantRead(task, this.keyGlob);
}
if (opts.forWriting) {
this.bucket.grantWrite(task, this.keyGlob);
}
} else {
const actions = new Array<string>();
if (opts.forReading) {
actions.push('s3:GetObject', 's3:ListBucket');
}
if (opts.forWriting) {
actions.push('s3:PutObject');
}
task.grantPrincipal.addToPolicy(new iam.PolicyStatement({ actions, resources: ['*'], }));
}
return { uri: this.uri };
}
}
Loading