fix(stepfunctions): Downscope SageMaker permissions (aws#2991)
The previous implementation used the `SageMakerFullAccess` managed
policy, which grants far broader permissions than the SageMaker job
needs. Instead, this commit scopes the permissions down to the specific
entities that are required, and only falls back to the `*` resource
when those entities are provided through the Step Function's input.
RomainMuller authored Jul 3, 2019
1 parent 47bf435 commit 69c82c8
Showing 10 changed files with 762 additions and 39 deletions.
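
The file shown below defines the shared SageMaker task types. It introduces the `S3Location` and `DockerImage` abstractions whose `bind` methods grant the task role only the S3 and ECR permissions it actually needs; the scoping itself happens in the `StandardS3Location.bind` and `StandardDockerImage.bind` implementations at the bottom of the file. A rough usage sketch, assuming these types are exported from `@aws-cdk/aws-stepfunctions-tasks` (the package this file presumably belongs to); the bucket, repository, and JSON path names are purely illustrative:

import ecr = require('@aws-cdk/aws-ecr');
import s3 = require('@aws-cdk/aws-s3');
import tasks = require('@aws-cdk/aws-stepfunctions-tasks');

declare const dataBucket: s3.IBucket;         // bucket holding training data and artifacts
declare const algorithmRepo: ecr.IRepository; // ECR repository hosting the training image

// Construct-backed inputs: bind() can scope the task role's policy to this
// repository and these key prefixes instead of attaching a broad managed policy.
const trainingImage = tasks.DockerImage.fromEcrRepository(algorithmRepo, 'v1.0');
const trainingData = tasks.S3Location.fromBucket(dataBucket, 'training/');
const modelOutput = tasks.S3Location.fromBucket(dataBucket, 'output/');

// A location only known from the execution input at runtime: the corresponding
// grants have to fall back to the '*' resource.
const runtimeData = tasks.S3Location.fromJsonExpression('$.trainingDataUri');
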
@@ -1,6 +1,13 @@
import ec2 = require('@aws-cdk/aws-ec2');
import ecr = require('@aws-cdk/aws-ecr');
import { DockerImageAsset, DockerImageAssetProps } from '@aws-cdk/aws-ecr-assets';
import iam = require('@aws-cdk/aws-iam');
import kms = require('@aws-cdk/aws-kms');
- import { Duration } from '@aws-cdk/core';
import s3 = require('@aws-cdk/aws-s3');
import sfn = require('@aws-cdk/aws-stepfunctions');
+ import { Construct, Duration } from '@aws-cdk/core';

export interface ISageMakerTask extends sfn.IStepFunctionsTask, iam.IGrantable {}

//
// Create Training Job types
@@ -24,7 +31,7 @@ export interface AlgorithmSpecification {
/**
* Registry path of the Docker image that contains the training algorithm.
*/
- readonly trainingImage?: string;
+ readonly trainingImage?: DockerImage;

/**
* Input mode that the algorithm supports.
@@ -125,7 +132,7 @@ export interface S3DataSource {
/**
* S3 Uri
*/
- readonly s3Uri: string;
+ readonly s3Location: S3Location;
}

/**
@@ -140,16 +147,22 @@ export interface OutputDataConfig {
/**
* Identifies the S3 path where you want Amazon SageMaker to store the model artifacts.
*/
- readonly s3OutputPath: string;
+ readonly s3OutputLocation: S3Location;
}

/**
* @experimental
*/
export interface StoppingCondition {
/**
* The maximum length of time, in seconds, that the training or compilation job can run.
*/
readonly maxRuntime?: Duration;
}

/**
* @experimental
*/
export interface ResourceConfig {

/**
@@ -169,7 +182,7 @@ export interface ResourceConfig {
/**
* KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job.
*/
- readonly volumeKmsKeyId?: kms.IKey;
+ readonly volumeEncryptionKey?: kms.IKey;

/**
* Size of the ML storage volume that you want to provision.
@@ -218,8 +231,139 @@ export interface MetricDefinition {
readonly regex: string;
}

/**
* @experimental
*/
export interface S3LocationConfig {
readonly uri: string;
}

/**
* Constructs `S3Location` objects.
*
* @experimental
*/
export abstract class S3Location {
/**
* An `S3Location` built from a specific bucket and key prefix.
*
* @param bucket is the bucket where the objects are to be stored.
* @param keyPrefix is the key prefix used by the location.
*/
public static fromBucket(bucket: s3.IBucket, keyPrefix: string): S3Location {
return new StandardS3Location({ bucket, keyPrefix, uri: bucket.urlForObject(keyPrefix) });
}

/**
* An `S3Location` determined entirely by a JSON path into the task input.
*
* Due to the dynamic nature of those locations, the IAM grants that will be set by `grantRead` and `grantWrite`
* apply to the `*` resource.
*
* @param expression the JSON expression resolving to an S3 location URI.
*/
public static fromJsonExpression(expression: string): S3Location {
return new StandardS3Location({ uri: sfn.Data.stringAt(expression) });
}

/**
* Called when the S3Location is bound to a StepFunctions task.
*/
public abstract bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig;
}

/**
* Options for binding an S3 Location.
*
* @experimental
*/
export interface S3LocationBindOptions {
/**
* Allow reading from the S3 Location.
*
* @default false
*/
readonly forReading?: boolean;

/**
* Allow writing to the S3 Location.
*
* @default false
*/
readonly forWriting?: boolean;
}

/**
* Configuration for using a Docker image.
*
* @experimental
*/
export interface DockerImageConfig {
/**
* The fully qualified URI of the Docker image.
*/
readonly imageUri: string;
}

/**
* Creates `DockerImage` instances.
*
* @experimental
*/
export abstract class DockerImage {
/**
* Reference a Docker image stored in an ECR repository.
*
* @param repository the ECR repository where the image is hosted.
* @param tag an optional `tag`
*/
public static fromEcrRepository(repository: ecr.IRepository, tag: string = 'latest'): DockerImage {
return new StandardDockerImage({ repository, imageUri: repository.repositoryUriForTag(tag) });
}

/**
* Reference a Docker image whose URI is obtained from the task's input.
*
* @param expression the JSON path expression that resolves to the image URI within the task input.
* @param allowAnyEcrImagePull whether ECR access should be permitted (set to `false` if the image will never be in ECR).
*/
public static fromJsonExpression(expression: string, allowAnyEcrImagePull = true): DockerImage {
return new StandardDockerImage({ imageUri: expression, allowAnyEcrImagePull });
}

/**
* Reference a Docker image by its URI.
*
* When referencing ECR images, prefer using `fromEcrRepository`.
*
* @param imageUri the URI of the Docker image.
*/
public static fromRegistry(imageUri: string): DockerImage {
return new StandardDockerImage({ imageUri });
}

/**
* Reference a Docker image that is provided as an Asset in the current app.
*
* @param scope the scope in which to create the Asset.
* @param id the ID for the asset in the construct tree.
* @param props the configuration props of the asset.
*/
public static fromAsset(scope: Construct, id: string, props: DockerImageAssetProps): DockerImage {
const asset = new DockerImageAsset(scope, id, props);
return new StandardDockerImage({ repository: asset.repository, imageUri: asset.imageUri });
}

/**
* Called when the image is used by a SageMaker task.
*/
public abstract bind(task: ISageMakerTask): DockerImageConfig;
}

/**
* S3 Data Type.
*
* @experimental
*/
export enum S3DataType {
/**
@@ -240,6 +384,8 @@ export enum S3DataType {

/**
* S3 Data Distribution Type.
*
* @experimental
*/
export enum S3DataDistributionType {
/**
@@ -255,6 +401,8 @@ export enum S3DataDistributionType {

/**
* Define the format of the input data.
*
* @experimental
*/
export enum RecordWrapperType {
/**
@@ -270,6 +418,8 @@ export enum RecordWrapperType {

/**
* Input mode that the algorithm supports.
*
* @experimental
*/
export enum InputMode {
/**
@@ -285,6 +435,8 @@ export enum InputMode {

/**
* Compression type of the data.
*
* @experimental
*/
export enum CompressionType {
/**
@@ -416,6 +568,8 @@ export interface TransformResources {

/**
* Specifies the number of records to include in a mini-batch for an HTTP inference request.
*
* @experimental
*/
export enum BatchStrategy {

@@ -432,6 +586,8 @@ export enum BatchStrategy {

/**
* Method to use to split the transform job's data files into smaller batches.
*
* @experimental
*/
export enum SplitType {

@@ -458,6 +614,8 @@ export enum SplitType {

/**
* How to assemble the results of the transform job as a single S3 object.
*
* @experimental
*/
export enum AssembleWith {

@@ -472,3 +630,70 @@ export enum AssembleWith {
LINE = 'Line'

}

class StandardDockerImage extends DockerImage {
private readonly allowAnyEcrImagePull: boolean;
private readonly imageUri: string;
private readonly repository?: ecr.IRepository;

constructor(opts: { allowAnyEcrImagePull?: boolean, imageUri: string, repository?: ecr.IRepository }) {
super();

this.allowAnyEcrImagePull = !!opts.allowAnyEcrImagePull;
this.imageUri = opts.imageUri;
this.repository = opts.repository;
}

public bind(task: ISageMakerTask): DockerImageConfig {
if (this.repository) {
this.repository.grantPull(task);
}
if (this.allowAnyEcrImagePull) {
task.grantPrincipal.addToPolicy(new iam.PolicyStatement({
actions: [
'ecr:BatchCheckLayerAvailability',
'ecr:GetDownloadUrlForLayer',
'ecr:BatchGetImage',
],
resources: ['*']
}));
}
return {
imageUri: this.imageUri,
};
}
}

class StandardS3Location extends S3Location {
private readonly bucket?: s3.IBucket;
private readonly keyGlob: string;
private readonly uri: string;

constructor(opts: { bucket?: s3.IBucket, keyPrefix?: string, uri: string }) {
super();
this.bucket = opts.bucket;
this.keyGlob = `${opts.keyPrefix || ''}*`;
this.uri = opts.uri;
}

public bind(task: ISageMakerTask, opts: S3LocationBindOptions): S3LocationConfig {
if (this.bucket) {
if (opts.forReading) {
this.bucket.grantRead(task, this.keyGlob);
}
if (opts.forWriting) {
this.bucket.grantWrite(task, this.keyGlob);
}
} else {
const actions = new Array<string>();
if (opts.forReading) {
actions.push('s3:GetObject', 's3:ListBucket');
}
if (opts.forWriting) {
actions.push('s3:PutObject');
}
task.grantPrincipal.addToPolicy(new iam.PolicyStatement({ actions, resources: ['*'], }));
}
return { uri: this.uri };
}
}
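
For completeness, a sketch of how one of the task constructs changed elsewhere in this commit is expected to consume these bindings. The variable declarations below are assumptions for illustration; the `bind` signatures and the resulting grants come from the code above:

import tasks = require('@aws-cdk/aws-stepfunctions-tasks');

declare const task: tasks.ISageMakerTask;  // the grantable task being configured
declare const output: tasks.S3Location;    // e.g. S3Location.fromBucket(bucket, 'output/')
declare const image: tasks.DockerImage;    // e.g. DockerImage.fromEcrRepository(repo)

// A bucket-backed location becomes bucket.grantWrite(task, '<keyPrefix>*');
// a JSON-path location becomes s3:PutObject on the '*' resource.
const outputConfig = output.bind(task, { forWriting: true });

// An ECR-backed image becomes repository.grantPull(task); a JSON-path image
// (with allowAnyEcrImagePull left enabled) adds the three ECR read actions on '*'.
const imageConfig = image.bind(task);

// outputConfig.uri and imageConfig.imageUri are then rendered into the task's
// CreateTrainingJob parameters (S3OutputPath and TrainingImage respectively).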