Skip to content

Commit

Permalink
Sagemaker Return values
Browse files Browse the repository at this point in the history
  • Loading branch information
biffgaut committed Feb 18, 2023
1 parent dbaa181 commit c5c048b
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,16 @@ export class LambdaToSagemakerEndpoint extends Construct {
}

// Build SageMaker Endpoint (inclduing SageMaker's Endpoint Configuration and Model)
[this.sagemakerEndpoint, this.sagemakerEndpointConfig, this.sagemakerModel] = defaults.BuildSagemakerEndpoint(
const buildSagemakerEndpointResponse = defaults.BuildSagemakerEndpoint(
this,
{
...props,
vpc: this.vpc,
}
);
this.sagemakerEndpoint = buildSagemakerEndpointResponse.endpoint;
this.sagemakerEndpointConfig = buildSagemakerEndpointResponse.endpointConfig;
this.sagemakerModel = buildSagemakerEndpointResponse.model;

// Setup the Lambda function
this.lambdaFunction = defaults.buildLambdaFunction(this, {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ test('Test getter methods: new Lambda function, existingSagemakerendpointObj (no
// Initial Setup
const stack = new Stack();

const [sagemakerEndpoint] = defaults.deploySagemakerEndpoint(stack, {
const deploySagemakerEndpointResponse = defaults.deploySagemakerEndpoint(stack, {
modelProps: {
primaryContainer: {
image: '<AccountId>.dkr.ecr.<region>.amazonaws.com/linear-learner:latest',
Expand All @@ -484,7 +484,7 @@ test('Test getter methods: new Lambda function, existingSagemakerendpointObj (no
});

const constructProps: LambdaToSagemakerEndpointProps = {
existingSagemakerEndpointObj: sagemakerEndpoint,
existingSagemakerEndpointObj: deploySagemakerEndpointResponse.endpoint,
lambdaFunctionProps: {
runtime: lambda.Runtime.PYTHON_3_8,
code: lambda.Code.fromAsset(`${__dirname}/lambda`),
Expand All @@ -509,7 +509,7 @@ test('Test getter methods: new Lambda function, existingSagemakerendpointObj and
// Initial Setup
const stack = new Stack();

const [sagemakerEndpoint] = defaults.deploySagemakerEndpoint(stack, {
const deploySagemakerEndpointResponse = defaults.deploySagemakerEndpoint(stack, {
modelProps: {
primaryContainer: {
image: '<AccountId>.dkr.ecr.<region>.amazonaws.com/linear-learner:latest',
Expand All @@ -519,7 +519,7 @@ test('Test getter methods: new Lambda function, existingSagemakerendpointObj and
});

const constructProps: LambdaToSagemakerEndpointProps = {
existingSagemakerEndpointObj: sagemakerEndpoint,
existingSagemakerEndpointObj: deploySagemakerEndpointResponse.endpoint,
lambdaFunctionProps: {
runtime: lambda.Runtime.PYTHON_3_8,
code: lambda.Code.fromAsset(`${__dirname}/lambda`),
Expand All @@ -546,7 +546,7 @@ test('Test lambda function custom environment variable', () => {
const stack = new Stack();

// Helper declaration
const [sagemakerEndpoint] = defaults.deploySagemakerEndpoint(stack, {
const deploySagemakerEndpointResponse = defaults.deploySagemakerEndpoint(stack, {
modelProps: {
primaryContainer: {
image: '<AccountId>.dkr.ecr.<region>.amazonaws.com/linear-learner:latest',
Expand All @@ -555,7 +555,7 @@ test('Test lambda function custom environment variable', () => {
},
});
new LambdaToSagemakerEndpoint(stack, 'test-lambda-sagemaker', {
existingSagemakerEndpointObj: sagemakerEndpoint,
existingSagemakerEndpointObj: deploySagemakerEndpointResponse.endpoint,
lambdaFunctionProps: {
runtime: lambda.Runtime.PYTHON_3_8,
handler: 'index.handler',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ stack.templateOptions.description = 'Integration Test for aws-lambda-sagemakeren

const model = getSagemakerModel(stack);

const [sagemakerEndpoint, endpointConfig, mModel] = defaults.deploySagemakerEndpoint(stack, {
const deploySagemakerEndpointResponse = defaults.deploySagemakerEndpoint(stack, {
modelProps: {
primaryContainer: {
image: model.mapping.findInMap(Stack.of(stack).region, "containerArn"),
Expand All @@ -35,12 +35,12 @@ const [sagemakerEndpoint, endpointConfig, mModel] = defaults.deploySagemakerEndp
},
});

sagemakerEndpoint.node.addDependency(model.asset);
endpointConfig?.node.addDependency(model.asset);
mModel?.node.addDependency(model.asset);
deploySagemakerEndpointResponse.endpoint.node.addDependency(model.asset);
deploySagemakerEndpointResponse.endpointConfig?.node.addDependency(model.asset);
deploySagemakerEndpointResponse.model?.node.addDependency(model.asset);

const constructProps: LambdaToSagemakerEndpointProps = {
existingSagemakerEndpointObj: sagemakerEndpoint,
existingSagemakerEndpointObj: deploySagemakerEndpointResponse.endpoint,
lambdaFunctionProps: {
runtime: lambda.Runtime.PYTHON_3_8,
code: lambda.Code.fromAsset(`${__dirname}/lambda`),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ export interface BuildSagemakerNotebookProps {
readonly role: iam.Role;
}

function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {
function addPermissions(role: iam.Role, props?: BuildSagemakerEndpointProps) {
// Grant permissions to NoteBookInstance for creating and training the model
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
resources: [`arn:${Aws.PARTITION}:sagemaker:${Aws.REGION}:${Aws.ACCOUNT_ID}:*`],
actions: [
Expand All @@ -81,7 +81,7 @@ function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {
);

// Grant CloudWatch Logging permissions
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
resources: [`arn:${cdk.Aws.PARTITION}:logs:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:log-group:/aws/sagemaker/*`],
actions: [
Expand All @@ -96,7 +96,7 @@ function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {

// To place the Sagemaker endpoint in a VPC
if (props && props.vpc) {
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
resources: ['*'],
actions: [
Expand All @@ -118,7 +118,7 @@ function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {

// To create a Sagemaker model using Bring-Your-Own-Model (BYOM) algorith image
// The image URL is specified in the modelProps
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
resources: [`arn:${cdk.Aws.PARTITION}:ecr:${cdk.Aws.REGION}:${cdk.Aws.ACCOUNT_ID}:repository/*`],
actions: [
Expand All @@ -132,7 +132,7 @@ function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {
);

// Add GetAuthorizationToken (it can not be bound to resources other than *)
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
resources: ['*'],
actions: ['ecr:GetAuthorizationToken'],
Expand All @@ -145,7 +145,7 @@ function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {
const acceleratorType = (props.endpointConfigProps
?.productionVariants as sagemaker.CfnEndpointConfig.ProductionVariantProperty[])[0].acceleratorType;
if (acceleratorType !== undefined) {
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
resources: ['*'],
actions: ['elastic-inference:Connect'],
Expand All @@ -155,7 +155,7 @@ function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {
}

// add kms permissions
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
// the kmsKeyId in the endpointConfigProps can be any of the following formats:
// Key ID: 1234abcd-12ab-34cd-56ef-1234567890ab
Expand All @@ -172,25 +172,25 @@ function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {
);

// Add S3 permissions to get Model artifact, put data capture files, etc.
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
actions: ['s3:GetObject', 's3:PutObject', 's3:DeleteObject', 's3:ListBucket'],
resources: ['arn:aws:s3:::*'],
})
);

// Grant GetRole permissions to the Sagemaker service
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
resources: [_role.roleArn],
resources: [role.roleArn],
actions: ['iam:GetRole'],
})
);

// Grant PassRole permissions to the Sagemaker service
_role.addToPolicy(
role.addToPolicy(
new iam.PolicyStatement({
resources: [_role.roleArn],
resources: [role.roleArn],
actions: ['iam:PassRole'],
conditions: {
StringLike: { 'iam:PassedToService': 'sagemaker.amazonaws.com' },
Expand All @@ -201,7 +201,7 @@ function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {
// Add CFN NAG uppress to allow for "Resource": "*" for ENI access in VPC,
// ECR authorization token for custom model images, and elastic inference
// Add CFN NAG for Complex Role because Sagmaker needs permissions to access several services
const roleDefaultPolicy = _role.node.tryFindChild('DefaultPolicy')?.node.findChild('Resource') as iam.CfnPolicy;
const roleDefaultPolicy = role.node.tryFindChild('DefaultPolicy')?.node.findChild('Resource') as iam.CfnPolicy;
addCfnSuppressRules(roleDefaultPolicy, [
{
id: 'W12',
Expand All @@ -214,10 +214,16 @@ function addPermissions(_role: iam.Role, props?: BuildSagemakerEndpointProps) {
]);
}

export interface BuildSagemakerNotebookResponse {
readonly notebook: sagemaker.CfnNotebookInstance,
readonly vpc?: ec2.IVpc,
readonly securityGroup?: ec2.SecurityGroup
}

export function buildSagemakerNotebook(
scope: Construct,
props: BuildSagemakerNotebookProps
): [sagemaker.CfnNotebookInstance, ec2.IVpc?, ec2.SecurityGroup?] {
): BuildSagemakerNotebookResponse {
// Setup the notebook properties
let sagemakerNotebookProps;
let vpcInstance;
Expand Down Expand Up @@ -287,13 +293,13 @@ export function buildSagemakerNotebook(
sagemakerNotebookProps
);
if (vpcInstance) {
return [sagemakerInstance, vpcInstance, securityGroup];
return { notebook: sagemakerInstance, vpc: vpcInstance, securityGroup };
} else {
return [sagemakerInstance];
return { notebook: sagemakerInstance };
}
} else {
// Return existing notebook object
return [props.existingNotebookObj];
return { notebook: props.existingNotebookObj };
}
}

Expand Down Expand Up @@ -330,28 +336,40 @@ export interface BuildSagemakerEndpointProps {
readonly vpc?: ec2.IVpc;
}

export interface BuildSagemakerEndpointResponse {
readonly endpoint: sagemaker.CfnEndpoint,
readonly endpointConfig?: sagemaker.CfnEndpointConfig,
readonly model?: sagemaker.CfnModel
}

export function BuildSagemakerEndpoint(
scope: Construct,
props: BuildSagemakerEndpointProps
): [sagemaker.CfnEndpoint, sagemaker.CfnEndpointConfig?, sagemaker.CfnModel?] {
): BuildSagemakerEndpointResponse {
/** Conditional Sagemaker endpoint creation */
if (!props.existingSagemakerEndpointObj) {
if (props.modelProps) {
/** return [endpoint, endpointConfig, model] */
return deploySagemakerEndpoint(scope, props);
const deploySagemakerEndpointResponse = deploySagemakerEndpoint(scope, props);
return { ...deploySagemakerEndpointResponse };
} else {
throw Error('Either existingSagemakerEndpointObj or at least modelProps is required');
}
} else {
/** Otherwise, return [endpoint] */
return [props.existingSagemakerEndpointObj];
return { endpoint: props.existingSagemakerEndpointObj };
}
}

export interface DeploySagemakerEndpointResponse {
readonly endpoint: sagemaker.CfnEndpoint,
readonly endpointConfig?: sagemaker.CfnEndpointConfig,
readonly model?: sagemaker.CfnModel
}

export function deploySagemakerEndpoint(
scope: Construct,
props: BuildSagemakerEndpointProps
): [sagemaker.CfnEndpoint, sagemaker.CfnEndpointConfig?, sagemaker.CfnModel?] {
): DeploySagemakerEndpointResponse {
let model: sagemaker.CfnModel;
let endpointConfig: sagemaker.CfnEndpointConfig;
let endpoint: sagemaker.CfnEndpoint;
Expand Down Expand Up @@ -386,7 +404,7 @@ export function deploySagemakerEndpoint(
// Add dependency on EndpointConfig
endpoint.addDependency(endpointConfig);

return [endpoint, endpointConfig, model];
return { endpoint, endpointConfig, model };
} else {
throw Error('You need to provide at least modelProps to create Sagemaker Endpoint');
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,10 @@ test('Test fail SageMaker endpoint check', () => {
},
};

const [endpoint] = BuildSagemakerEndpoint(stack, { modelProps });
const buildSagemakerEndpointResponse = BuildSagemakerEndpoint(stack, { modelProps });

const props: defaults.VerifiedProps = {
existingSagemakerEndpointObj: endpoint,
existingSagemakerEndpointObj: buildSagemakerEndpointResponse.endpoint,
endpointProps: {
endpointConfigName: 'placeholder'
}
Expand Down
Loading

0 comments on commit c5c048b

Please sign in to comment.