Skip to content

Commit

Permalink
feat(sagemaker): add model hosting L2 constructs
Browse files Browse the repository at this point in the history
Based closely on PR aws#2888, this commit introduces the Endpoint L2
construct and the L2 constructs on which it depends, including Model and
EndpointConfig. Departures from PR aws#2888 include:
- EndpointConfig definition moved into its own L2 construct as one
  EndpointConfig resource may be shared by multiple Endpoints.
- An Endpoint-specific IEndpointProductionVariant interface was added to
  support "metric*" and "autoScale*" APIs per endpoint-variant
  combination.
- The Notebook construct was excluded from this new commit to limit
  changes to model hosting use-cases only.
- Feedback on the earlier PR has been incorporated into this new commit.

fixes aws#2809

Co-authored-by: Matt McClean <mmcclean@amazon.com>
Co-authored-by: Yao <longyao@amazon.com>
Co-authored-by: Drew Jetter <60628154+jetterdj@users.noreply.github.com>
Co-authored-by: Murali Ganesh <59461079+foxpro24@users.noreply.github.com>
Co-authored-by: Abilash Rangoju <988529+rangoju@users.noreply.github.com>
  • Loading branch information
6 people committed Feb 4, 2020
1 parent 155b80e commit c2b8d37
Show file tree
Hide file tree
Showing 22 changed files with 4,844 additions and 19 deletions.
77 changes: 77 additions & 0 deletions packages/@aws-cdk/aws-sagemaker/lib/container-image.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import * as ecr from '@aws-cdk/aws-ecr';
import * as assets from "@aws-cdk/aws-ecr-assets";
import * as cdk from '@aws-cdk/core';
import { Model } from './model';

/**
* The configuration for creating a container image.
*/
export interface ContainerImageConfig {
/**
* The image name. Images in Amazon ECR repositories can be specified by either using the full registry/repository:tag or
* registry/repository@digest.
*
* For example, 012345678910.dkr.ecr.<region-name>.amazonaws.com/<repository-name>:latest or
* 012345678910.dkr.ecr.<region-name>.amazonaws.com/<repository-name>@sha256:94afd1f2e64d908bc90dbca0035a5b567EXAMPLE.
*/
readonly imageName: string;
}

/**
* Constructs for types of container images
*/
export abstract class ContainerImage {
/**
* Reference an image in an ECR repository
*/
public static fromEcrRepository(repository: ecr.IRepository, tag: string = 'latest'): ContainerImage {
return new EcrImage(repository, tag);
}

/**
* Reference an image that's constructed directly from sources on disk
*
* @param scope The scope within which to create the image asset
* @param id The id to assign to the image asset
* @param props The properties of a Docker image asset
*/
public static fromAsset(scope: cdk.Construct, id: string, props: assets.DockerImageAssetProps): ContainerImage {
return new AssetImage(scope, id, props);
}

/**
* Called when the image is used by a Model
*/
public abstract bind(scope: cdk.Construct, model: Model): ContainerImageConfig;
}

class EcrImage extends ContainerImage {
constructor(private readonly repository: ecr.IRepository, private readonly tag: string) {
super();
}

public bind(_scope: cdk.Construct, model: Model): ContainerImageConfig {
this.repository.grantPull(model);

return {
imageName: this.repository.repositoryUriForTag(this.tag)
};
}
}

class AssetImage extends ContainerImage {
private readonly asset: assets.DockerImageAsset;

constructor(readonly scope: cdk.Construct, readonly id: string, readonly props: assets.DockerImageAssetProps) {
super();
this.asset = new assets.DockerImageAsset(scope, id, props);
}

public bind(_scope: cdk.Construct, model: Model): ContainerImageConfig {
this.asset.repository.grantPull(model);

return {
imageName: this.asset.imageUri,
};
}
}
306 changes: 306 additions & 0 deletions packages/@aws-cdk/aws-sagemaker/lib/endpoint-config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
import * as ec2 from '@aws-cdk/aws-ec2';
import * as kms from '@aws-cdk/aws-kms';
import * as cdk from '@aws-cdk/core';
import { EOL } from 'os';
import { CfnEndpointConfig } from '.';
import { IModel } from './model';

/**
* The interface for a SageMaker EndpointConfig resource.
*/
export interface IEndpointConfig extends cdk.IResource {
/**
* The ARN of the endpoint configuration.
*
* @attribute
*/
readonly endpointConfigArn: string;
/**
* The name of the endpoint configuration.
*
* @attribute
*/
readonly endpointConfigName: string;
}

/**
* Construction properties for a production variant.
*/
export interface ProductionVariantProps {
/**
* The size of the Elastic Inference (EI) instance to use for the production variant. EI instances
* provide on-demand GPU computing for inference.
*
* @default none
*/
readonly acceleratorType?: AcceleratorType;
/**
* Number of instances to launch initially.
*
* @default 1
*/
readonly initialInstanceCount?: number;
/**
* Determines initial traffic distribution among all of the models that you specify in the
* endpoint configuration. The traffic to a production variant is determined by the ratio of the
* variant weight to the sum of all variant weight values across all production variants.
*
* @default 1.0
*/
readonly initialVariantWeight?: number;
/**
* Instance type of the production variant.
*
* @default ml.t2.medium instance type.
*/
readonly instanceType?: ec2.InstanceType;
/**
* The model to host.
*/
readonly model: IModel;
/**
* Name of the production variant.
*/
readonly variantName: string;
}

/**
* Represents a production variant that has been associated with an EndpointConfig.
*/
export interface ProductionVariant {
/**
* The size of the Elastic Inference (EI) instance to use for the production variant. EI instances
* provide on-demand GPU computing for inference.
*
* @default none
*/
readonly acceleratorType?: AcceleratorType;
/**
* Number of instances to launch initially.
*/
readonly initialInstanceCount: number;
/**
* Determines initial traffic distribution among all of the models that you specify in the
* endpoint configuration. The traffic to a production variant is determined by the ratio of the
* variant weight to the sum of all variant weight values across all production variants.
*/
readonly initialVariantWeight: number;
/**
* Instance type of the production variant.
*/
readonly instanceType: ec2.InstanceType;
/**
* The name of the model to host.
*/
readonly modelName: string;
/**
* The name of the production variant.
*/
readonly variantName: string;
}

/**
* Name tag constant
*/
const NAME_TAG: string = 'Name';

/**
* Construction properties for a SageMaker EndpointConfig.
*/
export interface EndpointConfigProps {
/**
* Name of the endpoint configuration.
*
* @default AWS CloudFormation generates a unique physical ID and uses that ID for the endpoint
* configuration's name.
*/
readonly endpointConfigName?: string;

/**
* Optional KMS encryption key associated with this stream.
*
* @default none
*/
readonly encryptionKey?: kms.IKey;

/**
* A ProductionVariantProps object.
*/
readonly productionVariant: ProductionVariantProps;

/**
* An optional list of extra ProductionVariantProps objects.
*
* @default none
*/
readonly extraProductionVariants?: ProductionVariantProps[];
}

/**
* The size of the Elastic Inference (EI) instance to use for the production variant. EI instances
* provide on-demand GPU computing for inference.
*/
export enum AcceleratorType {
/**
* Medium accelerator type.
*/
MEDIUM = 'ml.eia1.medium',
/**
* Large accelerator type.
*/
LARGE = 'ml.eia1.large ',
/**
* Extra large accelerator type.
*/
XLARGE = 'ml.eia1.xlarge',
}

/**
* Defines a SageMaker EndpointConfig.
*/
export class EndpointConfig extends cdk.Resource implements IEndpointConfig {
/**
* Imports an EndpointConfig defined either outside the CDK or in a different CDK stack.
* @param scope the Construct scope.
* @param id the resource id.
* @param endpointConfigName the name of the endpoint configuration.
*/
public static fromEndpointConfigName(scope: cdk.Construct, id: string, endpointConfigName: string): IEndpointConfig {
class Import extends cdk.Resource implements IEndpointConfig {
public endpointConfigName = endpointConfigName;
public endpointConfigArn = cdk.Stack.of(this).formatArn({
service: 'sagemaker',
resource: 'endpoint-config',
resourceName: this.endpointConfigName
});
}

return new Import(scope, id);
}

/**
* The ARN of the endpoint configuration.
*/
public readonly endpointConfigArn: string;
/**
* The name of the endpoint configuration.
*/
public readonly endpointConfigName: string;

private readonly _productionVariants: { [key: string]: ProductionVariant; } = {};

constructor(scope: cdk.Construct, id: string, props: EndpointConfigProps) {
super(scope, id, {
physicalName: props.endpointConfigName
});

// apply a name tag to the endpoint config resource
this.node.applyAspect(new cdk.Tag(NAME_TAG, this.node.path));

[props.productionVariant, ...props.extraProductionVariants || []].map(p => this.addProductionVariant(p));

// create the endpoint configuration resource
const endpointConfig = new CfnEndpointConfig(this, 'EndpointConfig', {
kmsKeyId: (props.encryptionKey) ? props.encryptionKey.keyArn : undefined,
endpointConfigName: this.physicalName,
productionVariants: cdk.Lazy.anyValue({ produce: () => this.renderProductionVariants() })
});
this.endpointConfigName = this.getResourceNameAttribute(endpointConfig.attrEndpointConfigName);
this.endpointConfigArn = this.getResourceArnAttribute(endpointConfig.ref, {
service: 'sagemaker',
resource: 'endpoint-config',
resourceName: this.physicalName,
});
}

/**
* Add production variant to the endpoint configuration.
*
* @param props The properties of a production variant to add.
*/
public addProductionVariant(props: ProductionVariantProps): void {
if (props.variantName in this._productionVariants) {
throw new Error(`There is already a Production Variant with name '${props.variantName}'`);
}
this.validateProps(props);
this._productionVariants[props.variantName] = {
acceleratorType: props.acceleratorType,
initialInstanceCount: props.initialInstanceCount || 1,
initialVariantWeight: props.initialVariantWeight || 1.0,
instanceType: props.instanceType || ec2.InstanceType.of(ec2.InstanceClass.T2, ec2.InstanceSize.MEDIUM),
modelName: props.model.modelName,
variantName: props.variantName
};
}

/**
* Get production variants associated with endpoint configuration.
*/
public get productionVariants(): ProductionVariant[] {
return Object.values(this._productionVariants);
}

/**
* Find production variant based on variant name
* @param name Variant name from production variant
*/
public findProductionVariant(name: string): ProductionVariant {
const ret = this._productionVariants[name];
if (!ret) {
throw new Error(`No variant with name: '${name}'`);
}
return ret;
}

protected validate(): string[] {
const result = super.validate();
// check we have 10 or fewer production variants
if (this.productionVariants.length > 10) {
result.push('Can\'t have more than 10 Production Variants');
}

return result;
}

private validateProps(props: ProductionVariantProps): void {
const errors: string[] = [];
// check instance count is greater than zero
if (props.initialInstanceCount !== undefined && props.initialInstanceCount < 1) {
errors.push('Must have at least one instance');
}

// check variant weight is not negative
if (props.initialVariantWeight && props.initialVariantWeight < 0) {
errors.push('Cannot have negative variant weight');
}

// validate the instance type
if (props.instanceType) {
// check if a valid SageMaker instance type
const instanceType = props.instanceType.toString();
if (!['c4', 'c5', 'c5d', 'g4dn', 'inf1', 'm4', 'm5', 'm5d', 'p2', 'p3', 'r5', 'r5d', 't2']
.some(instanceClass => instanceType.indexOf(instanceClass) >= 0)) {
errors.push(`Invalid instance type for a SageMaker Endpoint Production Variant: ${instanceType}`);
}
}

if (errors.length > 0) {
throw new Error(`Invalid Production Variant Props: ${errors.join(EOL)}`);
}
}

/**
* Render the list of production variants.
*/
private renderProductionVariants(): CfnEndpointConfig.ProductionVariantProperty[] {
return this.productionVariants.map( v => ({
acceleratorType: v.acceleratorType,
initialInstanceCount: v.initialInstanceCount,
initialVariantWeight: v.initialVariantWeight,
instanceType: 'ml.' + v.instanceType.toString(),
modelName: v.modelName,
variantName: v.variantName,
}) );
}

}
Loading

0 comments on commit c2b8d37

Please sign in to comment.