Skip to content

Commit

Permalink
feat(eks): trainium instance types (aws#29155)
Browse files Browse the repository at this point in the history
@freschri – It's a little hard to find docs on this but I think this is what you're after?

Closes aws#29131.

----

*By submitting this pull request, I confirm that my contribution is made under the terms of the Apache-2.0 license*
  • Loading branch information
msambol authored Mar 15, 2024
1 parent 98e9fbe commit 507b709
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 9 deletions.
4 changes: 2 additions & 2 deletions packages/aws-cdk-lib/aws-eks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ cluster.addNodegroupCapacity('custom-node-group', {
});
```

> **NOTE:** If you add instances with the inferentia (`inf1` or `inf2`) class the
> [neuron plugin](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/containers/dlc-then-eks-devflow.html)
> **NOTE:** If you add instances with the inferentia class (`inf1` or `inf2`) or trainium class (`trn1` or `trn1n`)
> the [neuron plugin](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/containers/dlc-then-eks-devflow.html)
> will be automatically installed in the kubernetes cluster.
#### Node Groups with IPv6 Support
Expand Down
28 changes: 21 additions & 7 deletions packages/aws-cdk-lib/aws-eks/lib/cluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1800,7 +1800,8 @@ export class Cluster extends ClusterBase {
spotInterruptHandler: options.spotInterruptHandler,
});

if (nodeTypeForInstanceType(options.instanceType) === NodeType.INFERENTIA) {
if (nodeTypeForInstanceType(options.instanceType) === NodeType.INFERENTIA ||
nodeTypeForInstanceType(options.instanceType) === NodeType.TRAINIUM ) {
this.addNeuronDevicePlugin();
}

Expand All @@ -1817,11 +1818,13 @@ export class Cluster extends ClusterBase {
* @param options options for creating a new nodegroup
*/
public addNodegroupCapacity(id: string, options?: NodegroupOptions): Nodegroup {
const hasInferentiaInstanceType = [
const hasInferentiaOrTrainiumInstanceType = [
options?.instanceType,
...options?.instanceTypes ?? [],
].some(i => i && nodeTypeForInstanceType(i) === NodeType.INFERENTIA);
if (hasInferentiaInstanceType) {
].some(i => i && (nodeTypeForInstanceType(i) === NodeType.INFERENTIA ||
nodeTypeForInstanceType(i) === NodeType.TRAINIUM));

if (hasInferentiaOrTrainiumInstanceType) {
this.addNeuronDevicePlugin();
}
return new Nodegroup(this, `Nodegroup${id}`, {
Expand Down Expand Up @@ -2373,6 +2376,7 @@ export class EksOptimizedImage implements ec2.IMachineImage {
'amazon-linux-2/' : 'amazon-linux-2-arm64/' : '')
+ (this.nodeType === NodeType.GPU ? 'amazon-linux-2-gpu/' : '')
+ (this.nodeType === NodeType.INFERENTIA ? 'amazon-linux-2-gpu/' : '')
+ (this.nodeType === NodeType.TRAINIUM ? 'amazon-linux-2-gpu/' : '')
+ 'recommended/image_id';
}

Expand Down Expand Up @@ -2410,6 +2414,11 @@ export enum NodeType {
* Inferentia instances
*/
INFERENTIA = 'INFERENTIA',

/**
* Trainium instances
*/
TRAINIUM = 'TRAINIUM',
}

/**
Expand Down Expand Up @@ -2471,9 +2480,14 @@ export enum MachineImageType {
}

function nodeTypeForInstanceType(instanceType: ec2.InstanceType) {
return INSTANCE_TYPES.gpu.includes(instanceType.toString().substring(0, 2)) ? NodeType.GPU :
INSTANCE_TYPES.inferentia.includes(instanceType.toString().substring(0, 4)) ? NodeType.INFERENTIA :
NodeType.STANDARD;
if (INSTANCE_TYPES.gpu.includes(instanceType.toString().substring(0, 2))) {
return NodeType.GPU;
} else if (INSTANCE_TYPES.inferentia.includes(instanceType.toString().substring(0, 4))) {
return NodeType.INFERENTIA;
} else if (INSTANCE_TYPES.trainium.includes(instanceType.toString().substring(0, 4))) {
return NodeType.TRAINIUM;
}
return NodeType.STANDARD;
}

function cpuArchForInstanceType(instanceType: ec2.InstanceType) {
Expand Down
1 change: 1 addition & 0 deletions packages/aws-cdk-lib/aws-eks/lib/instance-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ export const INSTANCE_TYPES = {
graviton: ['a1'],
graviton2: ['c6g', 'm6g', 'r6g', 't4g'],
graviton3: ['c7g'],
trainium: ['trn1', 'trn1n'],
};
36 changes: 36 additions & 0 deletions packages/aws-cdk-lib/aws-eks/test/cluster.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2209,6 +2209,42 @@ describe('cluster', () => {
Manifest: JSON.stringify([sanitized]),
});
});
test('trn1 instances are supported', () => {
// GIVEN
const { stack } = testFixtureNoVpc();
const cluster = new eks.Cluster(stack, 'Cluster', { defaultCapacity: 0, version: CLUSTER_VERSION, prune: false });

// WHEN
cluster.addAutoScalingGroupCapacity('TrainiumInstances', {
instanceType: new ec2.InstanceType('trn1.2xlarge'),
minCapacity: 1,
});
const fileContents = fs.readFileSync(path.join(__dirname, '..', 'lib', 'addons', 'neuron-device-plugin.yaml'), 'utf8');
const sanitized = YAML.parse(fileContents);

// THEN
Template.fromStack(stack).hasResourceProperties(eks.KubernetesManifest.RESOURCE_TYPE, {
Manifest: JSON.stringify([sanitized]),
});
});
test('trn1n instances are supported', () => {
// GIVEN
const { stack } = testFixtureNoVpc();
const cluster = new eks.Cluster(stack, 'Cluster', { defaultCapacity: 0, version: CLUSTER_VERSION, prune: false });

// WHEN
cluster.addAutoScalingGroupCapacity('TrainiumInstances', {
instanceType: new ec2.InstanceType('trn1n.2xlarge'),
minCapacity: 1,
});
const fileContents = fs.readFileSync(path.join(__dirname, '..', 'lib', 'addons', 'neuron-device-plugin.yaml'), 'utf8');
const sanitized = YAML.parse(fileContents);

// THEN
Template.fromStack(stack).hasResourceProperties(eks.KubernetesManifest.RESOURCE_TYPE, {
Manifest: JSON.stringify([sanitized]),
});
});

test('inf1 instances are supported in addNodegroupCapacity', () => {
// GIVEN
Expand Down

0 comments on commit 507b709

Please sign in to comment.