Skip to content
This repository was archived by the owner on Nov 19, 2025. It is now read-only.

Commit 5d6fabd

Browse files
committed
Add GPU instance support to the up command
Refer to #729
1 parent c8b53d0 commit 5d6fabd

File tree

5 files changed

+210
-83
lines changed

5 files changed

+210
-83
lines changed

ecs-cli/modules/cli/cluster/cluster_app.go

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"bufio"
1818
"fmt"
1919
"os"
20-
"regexp"
2120
"strconv"
2221
"strings"
2322

@@ -72,10 +71,6 @@ const (
7271
ParameterKeySpotPrice = "SpotPrice"
7372
)
7473

75-
const (
76-
defaultARM64InstanceType = "a1.medium"
77-
)
78-
7974
var flagNamesToStackParameterKeys map[string]string
8075
var requiredParameters []string = []string{ParameterKeyCluster}
8176

@@ -317,20 +312,13 @@ func createCluster(context *cli.Context, awsClients *AWSClients, commandConfig *
317312
}
318313

319314
if launchType == config.LaunchTypeEC2 {
320-
architecture, err := determineArchitecture(cfnParams)
321-
if err != nil {
322-
return err
323-
}
324-
325315
// Check if image id was supplied, else populate
326316
_, err = cfnParams.GetParameter(ParameterKeyAmiId)
327317
if err == cloudformation.ParameterNotFoundError {
328-
amiMetadata, err := metadataClient.GetRecommendedECSLinuxAMI(architecture)
318+
err := populateAMIID(cfnParams, metadataClient)
329319
if err != nil {
330320
return err
331321
}
332-
logrus.Infof("Using recommended %s AMI with ECS Agent %s and %s", amiMetadata.OsName, amiMetadata.AgentVersion, amiMetadata.RuntimeVersion)
333-
cfnParams.Add(ParameterKeyAmiId, amiMetadata.ImageID)
334322
} else if err != nil {
335323
return err
336324
}
@@ -393,26 +381,33 @@ func canEnableContainerInstanceTagging(client ecsclient.ECSClient) (bool, error)
393381
return false, nil
394382
}
395383

396-
func determineArchitecture(cfnParams *cloudformation.CfnStackParams) (string, error) {
397-
architecture := amimetadata.ArchitectureTypeX86
384+
func retrieveInstanceType(cfnParams *cloudformation.CfnStackParams) (string, error) {
385+
param, err := cfnParams.GetParameter(ParameterKeyInstanceType)
398386

399-
// a1 instances get the Arm based ECS AMI
400-
instanceTypeParam, err := cfnParams.GetParameter(ParameterKeyInstanceType)
401387
if err == cloudformation.ParameterNotFoundError {
402-
logrus.Infof("Defaulting instance type to t2.micro")
403-
} else if err != nil {
388+
logrus.Infof("Defaulting instance type to %s", cloudformation.DefaultECSInstanceType)
389+
return cloudformation.DefaultECSInstanceType, nil
390+
}
391+
if err != nil {
404392
return "", err
405-
} else {
406-
instanceType := aws.StringValue(instanceTypeParam.ParameterValue)
407-
// This regex matches all current a1 instances, and should work for any future additions as well
408-
r := regexp.MustCompile("a1\\.(medium|\\d*x?large)")
409-
if r.MatchString(instanceType) {
410-
logrus.Infof("Using Arm ecs-optimized AMI because instance type was %s", instanceType)
411-
architecture = amimetadata.ArchitectureTypeARM64
412-
}
393+
}
394+
return aws.StringValue(param.ParameterValue), nil
395+
}
396+
397+
func populateAMIID(cfnParams *cloudformation.CfnStackParams, client amimetadata.Client) error {
398+
instanceType, err := retrieveInstanceType(cfnParams)
399+
if err != nil {
400+
return err
413401
}
414402

415-
return architecture, nil
403+
amiMetadata, err := client.GetRecommendedECSLinuxAMI(instanceType)
404+
if err != nil {
405+
return err
406+
}
407+
logrus.Infof("Using recommended %s AMI with ECS Agent %s and %s",
408+
amiMetadata.OsName, amiMetadata.AgentVersion, amiMetadata.RuntimeVersion)
409+
cfnParams.Add(ParameterKeyAmiId, amiMetadata.ImageID)
410+
return nil
416411
}
417412

418413
// unfortunately go SDK lacks a unified Tag type

ecs-cli/modules/cli/cluster/cluster_app_test.go

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func TestClusterUpWithForce(t *testing.T) {
144144
)
145145

146146
gomock.InOrder(
147-
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
147+
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
148148
)
149149

150150
gomock.InOrder(
@@ -179,7 +179,7 @@ func TestClusterUpWithoutPublicIP(t *testing.T) {
179179
)
180180

181181
gomock.InOrder(
182-
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
182+
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
183183
)
184184

185185
gomock.InOrder(
@@ -232,7 +232,7 @@ func TestClusterUpWithUserData(t *testing.T) {
232232
)
233233

234234
gomock.InOrder(
235-
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
235+
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
236236
)
237237

238238
gomock.InOrder(
@@ -281,7 +281,7 @@ func TestClusterUpWithSpotPrice(t *testing.T) {
281281
)
282282

283283
gomock.InOrder(
284-
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
284+
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
285285
)
286286

287287
gomock.InOrder(
@@ -976,7 +976,7 @@ func TestClusterUpARM64(t *testing.T) {
976976
)
977977

978978
gomock.InOrder(
979-
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("arm64").Return(amiMetadata(armAMIID), nil),
979+
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("a1.medium").Return(amiMetadata(armAMIID), nil),
980980
)
981981

982982
gomock.InOrder(
@@ -1050,7 +1050,7 @@ func TestClusterUpWithTags(t *testing.T) {
10501050
}),
10511051
)
10521052
gomock.InOrder(
1053-
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
1053+
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
10541054
)
10551055
gomock.InOrder(
10561056
mockCloudformation.EXPECT().ValidateStackExists(stackName).Return(errors.New("error")),
@@ -1131,7 +1131,7 @@ func TestClusterUpWithTagsContainerInstanceTaggingEnabled(t *testing.T) {
11311131
}),
11321132
)
11331133
gomock.InOrder(
1134-
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
1134+
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
11351135
)
11361136
gomock.InOrder(
11371137
mockCloudformation.EXPECT().ValidateStackExists(stackName).Return(errors.New("error")),
@@ -1165,33 +1165,6 @@ func TestClusterUpWithTagsContainerInstanceTaggingEnabled(t *testing.T) {
11651165
assert.Equal(t, userdataMock.tags, expectedECSTags, "Expected tags to match")
11661166
}
11671167

1168-
func TestDetermineArchitecture(t *testing.T) {
1169-
var testCases = []struct {
1170-
in string
1171-
out string
1172-
}{
1173-
{"a1.medium", "arm64"},
1174-
{"a1.large", "arm64"},
1175-
{"a1.xlarge", "arm64"},
1176-
{"a1.2xlarge", "arm64"},
1177-
{"a1.4xlarge", "arm64"},
1178-
{"t2.medium", "x86"},
1179-
{"c5.large", "x86"},
1180-
{"i3.metal", "x86"},
1181-
{"t3.micro", "x86"},
1182-
}
1183-
1184-
for _, tt := range testCases {
1185-
t.Run(tt.in, func(t *testing.T) {
1186-
cfnParams := cloudformation.NewCfnStackParams(requiredParameters)
1187-
cfnParams.Add(ParameterKeyInstanceType, tt.in)
1188-
arch, err := determineArchitecture(cfnParams)
1189-
assert.NoError(t, err, "Unexpected error determining architecture")
1190-
assert.Equal(t, tt.out, arch, "Expected architecture to match")
1191-
})
1192-
}
1193-
}
1194-
11951168
///////////////////
11961169
// Cluster Down //
11971170
//////////////////
@@ -1397,7 +1370,7 @@ func mocksForSuccessfulClusterUp(mockECS *mock_ecs.MockECSClient, mockCloudforma
13971370
mockECS.EXPECT().CreateCluster(clusterName, gomock.Any()).Return(clusterName, nil),
13981371
)
13991372
gomock.InOrder(
1400-
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("x86").Return(amiMetadata(amiID), nil),
1373+
mockSSM.EXPECT().GetRecommendedECSLinuxAMI("t2.micro").Return(amiMetadata(amiID), nil),
14011374
)
14021375
gomock.InOrder(
14031376
mockCloudformation.EXPECT().ValidateStackExists(stackName).Return(errors.New("error")),

ecs-cli/modules/clients/aws/amimetadata/client.go

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,30 +11,36 @@
1111
// express or implied. See the License for the specific language governing
1212
// permissions and limitations under the License.
1313

14+
// Package amimetadata provides AMI metadata given an instance type.
1415
package amimetadata
1516

1617
import (
1718
"encoding/json"
18-
1919
"github.com/aws/amazon-ecs-cli/ecs-cli/modules/clients"
2020
"github.com/aws/amazon-ecs-cli/ecs-cli/modules/config"
2121
"github.com/aws/aws-sdk-go/aws"
2222
"github.com/aws/aws-sdk-go/aws/awserr"
2323
"github.com/aws/aws-sdk-go/service/ssm"
2424
"github.com/aws/aws-sdk-go/service/ssm/ssmiface"
2525
"github.com/pkg/errors"
26+
"github.com/sirupsen/logrus"
27+
"regexp"
28+
"strings"
2629
)
2730

31+
// SSM parameter names to retrieve ECS optimized AMI.
32+
// See: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/retrieve-ecs-optimized_AMI.html
2833
const (
29-
amazonLinux2X86RecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/recommended"
30-
amazonLinux2ARM64RecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/arm64/recommended"
31-
)
32-
33-
const (
34-
ArchitectureTypeARM64 = "arm64"
35-
ArchitectureTypeX86 = "x86"
34+
amazonLinux2X86RecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/recommended"
35+
amazonLinux2ARM64RecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/arm64/recommended"
36+
amazonLinux2X86GPURecommendedParameterName = "/aws/service/ecs/optimized-ami/amazon-linux-2/gpu/recommended"
3637
)
3738

39+
// AMIMetadata is returned through ssm:GetParameters and can be used to retrieve the ImageId
40+
// while launching instances.
41+
//
42+
// See: https://docs.aws.amazon.com/AmazonECS/latest/developerguide/retrieve-ecs-optimized_AMI.html
43+
// See: https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-properties-as-launchconfig.html#cfn-as-launchconfig-imageid
3844
type AMIMetadata struct {
3945
ImageID string `json:"image_id"`
4046
OsName string `json:"os"`
@@ -47,13 +53,13 @@ type Client interface {
4753
GetRecommendedECSLinuxAMI(string) (*AMIMetadata, error)
4854
}
4955

50-
// ssmClient implements Client
56+
// metadataClient implements Client.
5157
type metadataClient struct {
5258
client ssmiface.SSMAPI
5359
region string
5460
}
5561

56-
// NewSSMClient creates an instance of Client.
62+
// NewMetadataClient creates an instance of Client.
5763
func NewMetadataClient(commandConfig *config.CommandConfig) Client {
5864
client := ssm.New(commandConfig.Session)
5965
client.Handlers.Build.PushBackNamed(clients.CustomUserAgentHandler())
@@ -63,20 +69,31 @@ func NewMetadataClient(commandConfig *config.CommandConfig) Client {
6369
}
6470
}
6571

66-
func (c *metadataClient) GetRecommendedECSLinuxAMI(architecture string) (*AMIMetadata, error) {
67-
ssmParam := amazonLinux2X86RecommendedParameterName
68-
if architecture == ArchitectureTypeARM64 {
69-
ssmParam = amazonLinux2ARM64RecommendedParameterName
72+
// GetRecommendedECSLinuxAMI returns the recommended Amazon ECS-Optimized AMI Metadata given the instance type.
73+
func (c *metadataClient) GetRecommendedECSLinuxAMI(instanceType string) (*AMIMetadata, error) {
74+
if isARM64Instance(instanceType) {
75+
logrus.Infof("Using Arm ecs-optimized AMI because instance type was %s", instanceType)
76+
return c.parameterValueFor(amazonLinux2ARM64RecommendedParameterName)
77+
}
78+
if isGPUInstance(instanceType) {
79+
logrus.Infof("Using GPU ecs-optimized AMI because instance type was %s", instanceType)
80+
return c.parameterValueFor(amazonLinux2X86GPURecommendedParameterName)
7081
}
82+
return c.parameterValueFor(amazonLinux2X86RecommendedParameterName)
83+
}
7184

85+
func (c *metadataClient) parameterValueFor(ssmParamName string) (*AMIMetadata, error) {
7286
response, err := c.client.GetParameter(&ssm.GetParameterInput{
73-
Name: aws.String(ssmParam),
87+
Name: aws.String(ssmParamName),
7488
})
7589
if err != nil {
7690
if aerr, ok := err.(awserr.Error); ok {
7791
if aerr.Code() == ssm.ErrCodeParameterNotFound {
78-
// Added for arm AMIs which are only supported in some regions
79-
return nil, errors.Wrapf(err, "Could not find Recommended Amazon Linux 2 AMI in %s with architecture %s; the AMI may not be supported in this region", c.region, architecture)
92+
// Added for AMIs which are only supported in some regions
93+
return nil, errors.Wrapf(err,
94+
"Could not find Recommended Amazon Linux 2 AMI %s in %s; the AMI may not be supported in this region",
95+
ssmParamName,
96+
c.region)
8097
}
8198
}
8299
return nil, err
@@ -85,3 +102,24 @@ func (c *metadataClient) GetRecommendedECSLinuxAMI(architecture string) (*AMIMet
85102
err = json.Unmarshal([]byte(aws.StringValue(response.Parameter.Value)), metadata)
86103
return metadata, err
87104
}
105+
106+
func isARM64Instance(instanceType string) bool {
107+
r := regexp.MustCompile("a1\\.(medium|\\d*x?large)")
108+
if r.MatchString(instanceType) {
109+
return true
110+
}
111+
return false
112+
}
113+
114+
func isGPUInstance(instanceType string) bool {
115+
if strings.HasPrefix(instanceType, "p2.") {
116+
return true
117+
}
118+
if strings.HasPrefix(instanceType, "p3.") {
119+
return true
120+
}
121+
if strings.HasPrefix(instanceType, "p3dn.") {
122+
return true
123+
}
124+
return false
125+
}

0 commit comments

Comments
 (0)