Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: run local --use-task-role flag #5529

Merged
merged 16 commits into from
Dec 15, 2023
4 changes: 3 additions & 1 deletion internal/pkg/cli/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,11 @@ func (e *errTaskRoleRetrievalFailed) Error() string {
}

func (e *errTaskRoleRetrievalFailed) RecommendActions() string {
return fmt.Sprintf(`TaskRole retrieval failed. You can manually add permissions for your account to assume TaskRole by adding the following YAML override to your service:
return fmt.Sprintf(`TaskRole retrieval failed. If your containers don't require the TaskRole for local testing, you can use %s to disable this feature.
If you require the TaskRole, you can manually add permissions for your account to assume TaskRole by adding the following YAML override to your service:
%s
For more information on YAML overrides see %s`,
color.HighlightCode(`copilot run local --use-task-role=false`),
color.HighlightCodeBlock(`- op: add
path: /Resources/TaskRole/Properties/AssumeRolePolicyDocument/Statement/-
value:
Expand Down
144 changes: 141 additions & 3 deletions internal/pkg/cli/run_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@
package cli

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"os/signal"
Expand Down Expand Up @@ -65,6 +68,12 @@ const (
workloadAskPrompt = "Which workload would you like to run locally?"
)

const (
// Command to retrieve container credentials with ecs exec. See more at https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-iam-roles.html.
// Example output: {"AccessKeyId":"ACCESS_KEY_ID","Expiration":"EXPIRATION_DATE","RoleArn":"TASK_ROLE_ARN","SecretAccessKey":"SECRET_ACCESS_KEY","Token":"SECURITY_TOKEN_STRING"}
curlContainerCredentialsCmd = "curl 169.254.170.2$AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"
CaptainCarpensir marked this conversation as resolved.
Show resolved Hide resolved
)

type containerOrchestrator interface {
Start() <-chan error
RunTask(orchestrator.Task, ...orchestrator.RunTaskOption)
Expand Down Expand Up @@ -109,6 +118,7 @@ type runLocalOpts struct {

sel deploySelector
ecsClient ecsClient
ecsExecutor ecsCommandExecutor
ssm secretGetter
secretsManager secretGetter
sessProvider sessionProvider
Expand All @@ -133,6 +143,9 @@ type runLocalOpts struct {
labeledTermPrinter func(fw syncbuffer.FileWriter, bufs []*syncbuffer.LabeledSyncBuffer, opts ...syncbuffer.LabeledTermPrinterOption) clideploy.LabeledTermPrinter
unmarshal func([]byte) (manifest.DynamicWorkload, error)
newInterpolator func(app, env string) interpolator

captureStdout func() (io.Reader, error)
releaseStdout func()
}

func newRunLocalOpts(vars runLocalVars) (*runLocalOpts, error) {
Expand Down Expand Up @@ -184,6 +197,7 @@ func newRunLocalOpts(vars runLocalVars) (*runLocalOpts, error) {
// so use the default sess and *hope* they have permissions.
o.ecsClient = ecs.New(o.envManagerSess)
o.ssm = ssm.New(o.envManagerSess)
o.ecsExecutor = awsecs.New(o.envManagerSess)
o.secretsManager = secretsmanager.New(defaultSessEnvRegion)

resources, err := cloudformation.New(o.sess, cloudformation.WithProgressTracker(os.Stderr)).GetAppResourcesByRegion(o.targetApp, o.targetEnv.Region)
Expand Down Expand Up @@ -256,6 +270,32 @@ func newRunLocalOpts(vars runLocalVars) (*runLocalOpts, error) {
o.newRecursiveWatcher = func() (recursiveWatcher, error) {
return file.NewRecursiveWatcher(0)
}

// Capture stdout by replacing it with a piped writer and returning an attached io.Reader.
// Functions are concurrency safe and idempotent.
var mu sync.Mutex
var savedWriter, savedStdout *os.File
savedStdout = os.Stdout
o.captureStdout = func() (io.Reader, error) {
if savedWriter != nil {
savedWriter.Close()
}
pipeReader, pipeWriter, err := os.Pipe()
if err != nil {
return nil, err
}
mu.Lock()
savedWriter = pipeWriter
os.Stdout = savedWriter
mu.Unlock()
CaptainCarpensir marked this conversation as resolved.
Show resolved Hide resolved
return (io.Reader)(pipeReader), nil
}
o.releaseStdout = func() {
mu.Lock()
os.Stdout = savedStdout
mu.Unlock()
savedWriter.Close()
}
return o, nil
}

Expand Down Expand Up @@ -666,22 +706,119 @@ func (o *runLocalOpts) taskRoleCredentials(ctx context.Context) (map[string]stri

// ecsExecMethod tries to use ECS Exec to retrive credentials from running container
ecsExecMethod := func() (map[string]string, error) {
return nil, errors.New("ecs exec method not implemented")
svcDesc, err := o.ecsClient.DescribeService(o.appName, o.envName, o.wkldName)
if err != nil {
return nil, fmt.Errorf("describe ECS service for %s in environment %s: %w", o.wkldName, o.envName, err)
}

stdoutReader, err := o.captureStdout()
if err != nil {
return nil, err
}
defer o.releaseStdout()

// try exec on each container within the service
var wg sync.WaitGroup
containerErr := make(chan error)
for _, task := range svcDesc.Tasks {
taskID, err := awsecs.TaskID(aws.StringValue(task.TaskArn))
if err != nil {
return nil, err
}

for _, container := range task.Containers {
wg.Add(1)
containerName := aws.StringValue(container.Name)
go func() {
defer wg.Done()
containerErr <- o.ecsExecutor.ExecuteCommand(awsecs.ExecuteCommandInput{
CaptainCarpensir marked this conversation as resolved.
Show resolved Hide resolved
Cluster: svcDesc.ClusterName,
Command: fmt.Sprintf("/bin/sh -c %q\n", curlContainerCredentialsCmd),
Task: taskID,
Container: containerName,
})
}()
}
}

// wait for containers to finish and reset stdout
containersFinished := make(chan struct{})
go func() {
wg.Wait()
o.releaseStdout()
close(containersFinished)
}()

type containerCredentialsOutput struct {
AccessKeyId string
SecretAccessKey string
Token string
}

// parse stdout to try and find credentials
credsResult := make(chan map[string]string)
parseErr := make(chan error)
go func() {
select {
case <-containersFinished:
buf, err := io.ReadAll(stdoutReader)
if err != nil {
parseErr <- err
return
}
lines := bytes.Split(buf, []byte("\n"))
var creds containerCredentialsOutput
for _, line := range lines {
err := json.Unmarshal(line, &creds)
if err != nil {
continue
}
credsResult <- map[string]string{
"AWS_ACCESS_KEY_ID": creds.AccessKeyId,
"AWS_SECRET_ACCESS_KEY": creds.SecretAccessKey,
"AWS_SESSION_TOKEN": creds.Token,
}
return
}
parseErr <- errors.New("all containers failed to retrieve credentials")
case <-ctx.Done():
return
}
}()

var containerErrs []error
for {
select {
case creds := <-credsResult:
return creds, nil
case <-ctx.Done():
return nil, ctx.Err()
case err := <-parseErr:
return nil, errors.Join(append([]error{err}, containerErrs...)...)
case err := <-containerErr:
containerErrs = append(containerErrs, err)
}
}
}

credentialsChain := []func() (map[string]string, error){
assumeRoleMethod,
ecsExecMethod,
}

credentialsChainWrappedErrs := []string{
"assume role",
"ecs exec",
}

// return TaskRole credentials from first successful method
var errs []error
for _, method := range credentialsChain {
for errIndex, method := range credentialsChain {
vars, err := method()
if err == nil {
return vars, nil
}
errs = append(errs, err)
errs = append(errs, fmt.Errorf("%s: %w", credentialsChainWrappedErrs[errIndex], err))
}

return nil, &errTaskRoleRetrievalFailed{errs}
Expand Down Expand Up @@ -1056,6 +1193,7 @@ func BuildRunLocalCmd() *cobra.Command {
cmd.Flags().StringVarP(&vars.envName, envFlag, envFlagShort, "", envFlagDescription)
cmd.Flags().StringVarP(&vars.appName, appFlag, appFlagShort, tryReadingAppName(), appFlagDescription)
cmd.Flags().BoolVar(&vars.watch, watchFlag, false, watchFlagDescription)
cmd.Flags().BoolVar(&vars.useTaskRole, useTaskRoleFlag, true, useTaskRoleFlagDescription)
CaptainCarpensir marked this conversation as resolved.
Show resolved Hide resolved
cmd.Flags().Var(&vars.portOverrides, portOverrideFlag, portOverridesFlagDescription)
cmd.Flags().StringToStringVar(&vars.envOverrides, envVarOverrideFlag, nil, envVarOverrideFlagDescription)
cmd.Flags().BoolVar(&vars.proxy, proxyFlag, false, proxyFlagDescription)
Expand Down
Loading