Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(stepfunctions): Downscope SageMaker permissions #2991

Merged
merged 12 commits into from
Jul 3, 2019
Merged
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import ec2 = require('@aws-cdk/aws-ec2');
import ecr = require('@aws-cdk/aws-ecr');
import iam = require('@aws-cdk/aws-iam');
import kms = require('@aws-cdk/aws-kms');
import s3 = require('@aws-cdk/aws-s3');
import sfn = require('@aws-cdk/aws-stepfunctions');
import { Duration } from '@aws-cdk/cdk';

//
Expand All @@ -24,7 +28,7 @@ export interface AlgorithmSpecification {
/**
* Registry path of the Docker image that contains the training algorithm.
*/
readonly trainingImage?: string;
readonly trainingImage?: IDockerImage;

/**
* Input mode that the algorithm supports.
Expand Down Expand Up @@ -125,7 +129,7 @@ export interface S3DataSource {
/**
* S3 Uri
*/
readonly s3Uri: string;
readonly s3Location: IS3Location;
}

/**
Expand All @@ -140,7 +144,7 @@ export interface OutputDataConfig {
/**
* Identifies the S3 path where you want Amazon SageMaker to store the model artifacts.
*/
readonly s3OutputPath: string;
readonly s3OutputLocation: IS3Location;
}

export interface StoppingCondition {
Expand Down Expand Up @@ -169,7 +173,7 @@ export interface ResourceConfig {
/**
* KMS key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job.
*/
readonly volumeKmsKeyId?: kms.IKey;
readonly volumeEncryptionKey?: kms.IKey;

/**
* Size of the ML storage volume that you want to provision.
Expand Down Expand Up @@ -218,6 +222,148 @@ export interface MetricDefinition {
readonly regex: string;
}

/**
* Specifies a location in a S3 Bucket.
*
* @experimental
*/
export interface IS3Location {
RomainMuller marked this conversation as resolved.
Show resolved Hide resolved
/** The URI of the location in S3 */
readonly uri: string;

/** Grants read permissions to the S3 location. */
grantRead(grantee: iam.IGrantable): void;
/** Grants write permissions to the S3 location. */
grantWrite(grantee: iam.IGrantable): void;
}

/**
* Constructs `IS3Location` objects.
*/
export class S3Location {
/**
* An `IS3Location` built with a determined bucket and key prefix.
*
* @param bucket is the bucket where the objects are to be stored.
* @param keyPrefix is the key prefix used by the location.
*/
public static inBucket(bucket: s3.IBucket, keyPrefix: string): IS3Location {
return {
uri: bucket.urlForObject(keyPrefix),
grantRead: (grantee) => bucket.grantRead(grantee, keyPrefix + '*'),
grantWrite: (grantee) => bucket.grantWrite(grantee, keyPrefix + '*'),
};
}

/**
* An `IS3Location` determined fully by a JSON Path from the task input.
*
* Due to the dynamic nature of those locations, the IAM grants that will be set by `grantRead` and `grantWrite`
* apply to the `*` resource.
*
* @param expression the JSON expression resolving to an S3 location URI.
*/
public static fromJsonExpression(expression: string): IS3Location {
return {
uri: sfn.Data.stringAt(expression),
grantRead: grantee => grantee.grantPrincipal.addToPolicy(new iam.PolicyStatement({
actions: ['s3:GetObject', `s3:ListBucket`],
resources: ['*']
})),
grantWrite: grantee => grantee.grantPrincipal.addToPolicy(new iam.PolicyStatement({
actions: ['s3:PutObject'],
resources: ['*']
})),
};
}

private constructor() { }
}

export interface IDockerImage {
RomainMuller marked this conversation as resolved.
Show resolved Hide resolved
readonly name: string;
grantRead(grantee: iam.IGrantable): void;
RomainMuller marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* Specifies options for accessing images in ECR.
*/
export interface EcrImageOptions {
/**
* The tag to use for this ECR Image. This option is mutually exclusive with `digest`.
*/
readonly tag?: string;

/**
* The digest to use for this ECR Image. This option is mutually exclusive with `tag`.
*/
readonly digest?: string;
}

/**
* Creates `IDockerImage` instances.
*/
export class DockerImage {
/**
* Reference a Docker image stored in an ECR repository.
*
* @param repo the ECR repository where the image is hosted.
* @param opts an optional `tag` or `digest` to use.
*/
public static inEcr(repo: ecr.IRepository, opts: EcrImageOptions = {}): IDockerImage {
RomainMuller marked this conversation as resolved.
Show resolved Hide resolved
if (opts.tag && opts.digest) {
throw new Error(`The tag and digest options are mutually exclusive, but both were specified`);
}
const suffix = opts.tag
? `:${opts.tag}`
: opts.digest
? `@${opts.digest}`
: '';
return {
name: repo.repositoryUri + suffix,
grantRead: repo.grantPull.bind(repo),
};
}

/**
* Reference a Docker image which URI is obtained from the task's input.
*
* @param expression the JSON path expression with the task input.
* @param enableEcrAccess whether ECR access should be permitted (set to `false` if the image will never be in ECR).
*/
public static fromJsonExpression(expression: string, enableEcrAccess = true): IDockerImage {
return this.fromImageUri(sfn.Data.stringAt(expression), enableEcrAccess);
}

/**
* Reference a Docker image by it's URI.
*
* When referencing ECR images, prefer using `inEcr`.
*
* @param uri the URI to the docker image.
* @param enableEcrAccess whether ECR access should be permitted (set to `true` if the image is located in an ECR
* repository).
*/
public static fromImageUri(uri: string, enableEcrAccess = false): IDockerImage {
RomainMuller marked this conversation as resolved.
Show resolved Hide resolved
return {
name: uri,
grantRead(grantee) {
if (!enableEcrAccess) { return; }
grantee.grantPrincipal.addToPolicy(new iam.PolicyStatement({
actions: [
'ecr:BatchCheckLayerAvailability',
'ecr:GetDownloadUrlForLayer',
'ecr:BatchGetImage'
],
resources: ['*']
}));
},
};
}

private constructor() { }
}

/**
* S3 Data Type.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@ import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceC
/**
* @experimental
*/
export interface SagemakerTrainProps {
export interface SagemakerTrainTaskProps {

/**
* Training Job Name.
*/
readonly trainingJobName: string;

/**
* Role for thte Training Job.
* Role for the Training Job. The role must be granted all necessary permissions for the SageMaker training job to
* be able to operate.
*
* See https://docs.aws.amazon.com/fr_fr/sagemaker/latest/dg/sagemaker-roles.html#sagemaker-roles-createtrainingjob-perms
*
* @default - a role with appropriate permissions will be created.
*/
readonly role?: iam.IRole;

Expand Down Expand Up @@ -71,7 +76,7 @@ export interface SagemakerTrainProps {
/**
* Class representing the SageMaker Create Training Job task.
*/
export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsTask {
export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn.IStepFunctionsTask {

/**
* Allows specify security group connections for instances of this fleet.
Expand Down Expand Up @@ -105,7 +110,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
*/
private readonly stoppingCondition: StoppingCondition;

constructor(scope: Construct, private readonly props: SagemakerTrainProps) {
constructor(scope: Construct, private readonly props: SagemakerTrainTaskProps) {

// set the default resource config if not defined.
this.resourceConfig = props.resourceConfig || {
Expand All @@ -122,11 +127,55 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
// set the sagemaker role or create new one
this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
managedPolicies: [
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')
]
inlinePolicies: {
CreateTrainingJob: new iam.PolicyDocument({
statements: [
new iam.PolicyStatement({
actions: [
'cloudwatch:PutMetricData',
'logs:CreateLogStream',
'logs:PutLogEvents',
'logs:CreateLogGroup',
'logs:DescribeLogStreams',
'ecr:GetAuthorizationToken',
...props.vpcConfig
? [
'ec2:CreateNetworkInterface',
'ec2:CreateNetworkInterfacePermission',
'ec2:DeleteNetworkInterface',
'ec2:DeleteNetworkInterfacePermission',
'ec2:DescribeNetworkInterfaces',
'ec2:DescribeVpcs',
'ec2:DescribeDhcpOptions',
'ec2:DescribeSubnets',
'ec2:DescribeSecurityGroups',
]
: [],
],
resources: ['*'], // Those permissions cannot be resource-scoped
})
]
}),
}
});

for (const input of props.inputDataConfig) {
input.dataSource.s3DataSource.s3Location.grantRead(this.role);
}
props.outputDataConfig.s3OutputLocation.grantWrite(this.role);

if (props.outputDataConfig.encryptionKey) {
props.outputDataConfig.encryptionKey.grantEncrypt(this.role);
}

if (props.algorithmSpecification.trainingImage) {
props.algorithmSpecification.trainingImage.grantRead(this.role);
}

if (props.resourceConfig && props.resourceConfig.volumeEncryptionKey) {
props.resourceConfig.volumeEncryptionKey.grant(this.role, 'kms:CreateGrant');
}

// set the input mode to 'File' if not defined
this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ?
( props.algorithmSpecification ) :
Expand All @@ -148,6 +197,10 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
}
}

public get grantPrincipal(): iam.IPrincipal {
RomainMuller marked this conversation as resolved.
Show resolved Hide resolved
return this.role;
}

public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
return {
resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + (this.props.synchronous ? '.sync' : ''),
Expand Down Expand Up @@ -190,7 +243,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
ChannelName: channel.channelName,
DataSource: {
S3DataSource: {
S3Uri: channel.dataSource.s3DataSource.s3Uri,
S3Uri: channel.dataSource.s3DataSource.s3Location.uri,
S3DataType: channel.dataSource.s3DataSource.s3DataType,
...(channel.dataSource.s3DataSource.s3DataDistributionType) ?
{ S3DataDistributionType: channel.dataSource.s3DataSource.s3DataDistributionType} : {},
Expand All @@ -209,7 +262,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
private renderOutputDataConfig(config: OutputDataConfig): {[key: string]: any} {
return {
OutputDataConfig: {
S3OutputPath: config.s3OutputPath,
S3OutputPath: config.s3OutputLocation.uri,
...(config.encryptionKey) ? { KmsKeyId: config.encryptionKey.keyArn } : {},
}
};
Expand All @@ -221,7 +274,7 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
InstanceCount: config.instanceCount,
InstanceType: 'ml.' + config.instanceType,
VolumeSizeInGB: config.volumeSizeInGB,
...(config.volumeKmsKeyId) ? { VolumeKmsKeyId: config.volumeKmsKeyId.keyArn } : {},
...(config.volumeEncryptionKey) ? { VolumeKmsKeyId: config.volumeEncryptionKey.keyArn } : {},
}
};
}
Expand Down Expand Up @@ -260,7 +313,8 @@ export class SagemakerTrainTask implements ec2.IConnectable, sfn.IStepFunctionsT
stack.formatArn({
service: 'sagemaker',
resource: 'training-job',
resourceName: '*'
// If the job name comes from input, we cannot target the policy to a particular ARN prefix reliably...
resourceName: sfn.Data.isJsonPath(this.props.trainingJobName) ? '*' : `${this.props.trainingJobName}*`
})
],
}),
Expand Down
6 changes: 5 additions & 1 deletion packages/@aws-cdk/aws-stepfunctions-tasks/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@
"dependencies": {
"@aws-cdk/aws-cloudwatch": "^0.35.0",
"@aws-cdk/aws-ec2": "^0.35.0",
"@aws-cdk/aws-ecr": "^0.35.0",
"@aws-cdk/aws-ecs": "^0.35.0",
"@aws-cdk/aws-iam": "^0.35.0",
"@aws-cdk/aws-kms": "^0.35.0",
"@aws-cdk/aws-lambda": "^0.35.0",
"@aws-cdk/aws-s3": "^0.35.0",
"@aws-cdk/aws-sns": "^0.35.0",
"@aws-cdk/aws-sqs": "^0.35.0",
"@aws-cdk/aws-stepfunctions": "^0.35.0",
Expand All @@ -93,10 +95,12 @@
"peerDependencies": {
"@aws-cdk/aws-cloudwatch": "^0.35.0",
"@aws-cdk/aws-ec2": "^0.35.0",
"@aws-cdk/aws-ecr": "^0.35.0",
"@aws-cdk/aws-ecs": "^0.35.0",
"@aws-cdk/aws-iam": "^0.35.0",
"@aws-cdk/aws-kms": "^0.35.0",
"@aws-cdk/aws-lambda": "^0.35.0",
"@aws-cdk/aws-s3": "^0.35.0",
"@aws-cdk/aws-sns": "^0.35.0",
"@aws-cdk/aws-sqs": "^0.35.0",
"@aws-cdk/aws-stepfunctions": "^0.35.0",
Expand All @@ -106,4 +110,4 @@
"node": ">= 8.10.0"
},
"stability": "experimental"
}
}
Loading