From 3b99e1b4879f760613f9f2d955f23e33cd45b593 Mon Sep 17 00:00:00 2001 From: Bryce Carman Date: Thu, 8 Feb 2018 08:24:27 -0800 Subject: [PATCH] Add AWS cloud provider option for IAM role Currently the AWS cloud provider uses the EC2 instance role when interacting with AWS APIs. This change gives the option to provide and IAM role that the cloud provider will assume before calling the APIs. All resources created by the role will be owned by that account instead of the account where the EC2 instance is running. --- pkg/cloudprovider/providers/aws/BUILD | 2 + pkg/cloudprovider/providers/aws/aws.go | 61 ++++++++++++++----- pkg/cloudprovider/providers/aws/aws_test.go | 47 ++++++++------ .../providers/aws/regions_test.go | 2 +- pkg/cloudprovider/providers/aws/tags_test.go | 3 +- 5 files changed, 77 insertions(+), 38 deletions(-) diff --git a/pkg/cloudprovider/providers/aws/BUILD b/pkg/cloudprovider/providers/aws/BUILD index 7fb2d5e4eb5b4..4ce36c7227335 100644 --- a/pkg/cloudprovider/providers/aws/BUILD +++ b/pkg/cloudprovider/providers/aws/BUILD @@ -38,6 +38,7 @@ go_library( "//vendor/github.com/aws/aws-sdk-go/aws/awserr:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds:go_default_library", + "//vendor/github.com/aws/aws-sdk-go/aws/credentials/stscreds:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/ec2metadata:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/request:go_default_library", "//vendor/github.com/aws/aws-sdk-go/aws/session:go_default_library", @@ -46,6 +47,7 @@ go_library( "//vendor/github.com/aws/aws-sdk-go/service/elb:go_default_library", "//vendor/github.com/aws/aws-sdk-go/service/elbv2:go_default_library", "//vendor/github.com/aws/aws-sdk-go/service/kms:go_default_library", + "//vendor/github.com/aws/aws-sdk-go/service/sts:go_default_library", "//vendor/github.com/golang/glog:go_default_library", "//vendor/github.com/prometheus/client_golang/prometheus:go_default_library", "//vendor/gopkg.in/gcfg.v1:go_default_library", diff --git a/pkg/cloudprovider/providers/aws/aws.go b/pkg/cloudprovider/providers/aws/aws.go index 6e9b5ac718aff..31faa871d682d 100644 --- a/pkg/cloudprovider/providers/aws/aws.go +++ b/pkg/cloudprovider/providers/aws/aws.go @@ -33,6 +33,7 @@ import ( "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" @@ -41,6 +42,7 @@ import ( "github.com/aws/aws-sdk-go/service/elb" "github.com/aws/aws-sdk-go/service/elbv2" "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go/service/sts" "github.com/golang/glog" "github.com/prometheus/client_golang/prometheus" clientset "k8s.io/client-go/kubernetes" @@ -526,6 +528,9 @@ type CloudConfig struct { // RouteTableID enables using a specific RouteTable RouteTableID string + // RoleARN is the IAM role to assume when interaction with AWS APIs. + RoleARN string + // KubernetesClusterTag is the legacy cluster id we'll use to identify our cluster resources KubernetesClusterTag string // KubernetesClusterID is the cluster id we'll use to identify our cluster resources @@ -927,22 +932,43 @@ func (s *awsSdkEC2) DescribeVpcs(request *ec2.DescribeVpcsInput) (*ec2.DescribeV func init() { registerMetrics() cloudprovider.RegisterCloudProvider(ProviderName, func(config io.Reader) (cloudprovider.Interface, error) { + cfg, err := readAWSCloudConfig(config) + if err != nil { + return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) + } + + sess, err := session.NewSession(&aws.Config{}) + if err != nil { + return nil, fmt.Errorf("unable to initialize AWS session: %v", err) + } + + var provider credentials.Provider + if cfg.Global.RoleARN == "" { + provider = &ec2rolecreds.EC2RoleProvider{ + Client: ec2metadata.New(sess), + } + } else { + glog.Infof("Using AWS assumed role %v", cfg.Global.RoleARN) + provider = &stscreds.AssumeRoleProvider{ + Client: sts.New(sess), + RoleARN: cfg.Global.RoleARN, + } + } + creds := credentials.NewChainCredentials( []credentials.Provider{ &credentials.EnvProvider{}, - &ec2rolecreds.EC2RoleProvider{ - Client: ec2metadata.New(session.New(&aws.Config{})), - }, + provider, &credentials.SharedCredentialsProvider{}, }) aws := newAWSSDKProvider(creds) - return newAWSCloud(config, aws) + return newAWSCloud(*cfg, aws) }) } // readAWSCloudConfig reads an instance of AWSCloudConfig from config reader. -func readAWSCloudConfig(config io.Reader, metadata EC2Metadata) (*CloudConfig, error) { +func readAWSCloudConfig(config io.Reader) (*CloudConfig, error) { var cfg CloudConfig var err error @@ -953,20 +979,25 @@ func readAWSCloudConfig(config io.Reader, metadata EC2Metadata) (*CloudConfig, e } } + return &cfg, nil +} + +func updateConfigZone(cfg *CloudConfig, metadata EC2Metadata) error { if cfg.Global.Zone == "" { if metadata != nil { glog.Info("Zone not specified in configuration file; querying AWS metadata service") + var err error cfg.Global.Zone, err = getAvailabilityZone(metadata) if err != nil { - return nil, err + return err } } if cfg.Global.Zone == "" { - return nil, fmt.Errorf("no zone specified in configuration file") + return fmt.Errorf("no zone specified in configuration file") } } - return &cfg, nil + return nil } func getInstanceType(metadata EC2Metadata) (string, error) { @@ -989,7 +1020,7 @@ func azToRegion(az string) (string, error) { // newAWSCloud creates a new instance of AWSCloud. // AWSProvider and instanceId are primarily for tests -func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { +func newAWSCloud(cfg CloudConfig, awsServices Services) (*Cloud, error) { // We have some state in the Cloud object - in particular the attaching map // Log so that if we are building multiple Cloud objects, it is obvious! glog.Infof("Building AWS cloudprovider") @@ -999,9 +1030,9 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { return nil, fmt.Errorf("error creating AWS metadata client: %q", err) } - cfg, err := readAWSCloudConfig(config, metadata) + err = updateConfigZone(&cfg, metadata) if err != nil { - return nil, fmt.Errorf("unable to read AWS cloud provider config file: %v", err) + return nil, fmt.Errorf("unable to determine AWS zone from cloud provider config or EC2 instance metadata: %v", err) } zone := cfg.Global.Zone @@ -1059,7 +1090,7 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { asg: asg, metadata: metadata, kms: kms, - cfg: cfg, + cfg: &cfg, region: regionName, attaching: make(map[types.NodeName]map[mountDevice]awsVolumeID), @@ -1067,8 +1098,9 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { } awsCloud.instanceCache.cloud = awsCloud - if cfg.Global.VPC != "" && cfg.Global.SubnetID != "" && (cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != "") { - // When the master is running on a different AWS account, cloud provider or on-premises + tagged := cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != "" + if cfg.Global.VPC != "" && (cfg.Global.SubnetID != "" || cfg.Global.RoleARN != "") && tagged { + // When the master is running on a different AWS account, cloud provider or on-premise // build up a dummy instance and use the VPC from the nodes account glog.Info("Master is configured to run on a different AWS account, different cloud provider or on-premises") awsCloud.selfAWSInstance = &awsInstance{ @@ -1084,7 +1116,6 @@ func newAWSCloud(config io.Reader, awsServices Services) (*Cloud, error) { } awsCloud.selfAWSInstance = selfAWSInstance awsCloud.vpcID = selfAWSInstance.vpcID - } if cfg.Global.KubernetesClusterTag != "" || cfg.Global.KubernetesClusterID != "" { diff --git a/pkg/cloudprovider/providers/aws/aws_test.go b/pkg/cloudprovider/providers/aws/aws_test.go index f7cc8012dec3d..5e9601730b752 100644 --- a/pkg/cloudprovider/providers/aws/aws_test.go +++ b/pkg/cloudprovider/providers/aws/aws_test.go @@ -160,7 +160,10 @@ func TestReadAWSCloudConfig(t *testing.T) { if test.aws != nil { metadata, _ = test.aws.Metadata() } - cfg, err := readAWSCloudConfig(test.reader, metadata) + cfg, err := readAWSCloudConfig(test.reader) + if err == nil { + err = updateConfigZone(cfg, metadata) + } if test.expectError { if err == nil { t.Errorf("Should error for case %s (cfg=%v)", test.name, cfg) @@ -213,7 +216,11 @@ func TestNewAWSCloud(t *testing.T) { for _, test := range tests { t.Logf("Running test case %s", test.name) - c, err := newAWSCloud(test.reader, test.awsServices) + cfg, err := readAWSCloudConfig(test.reader) + var c *Cloud + if err == nil { + c, err = newAWSCloud(*cfg, test.awsServices) + } if test.expectError { if err == nil { t.Errorf("Should error for case %s", test.name) @@ -233,7 +240,7 @@ func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (* awsServices := newMockedFakeAWSServices(TestClusterId) awsServices.instances = instances awsServices.selfInstance = selfInstance - awsCloud, err := newAWSCloud(nil, awsServices) + awsCloud, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { panic(err) } @@ -242,7 +249,7 @@ func mockInstancesResp(selfInstance *ec2.Instance, instances []*ec2.Instance) (* func mockAvailabilityZone(availabilityZone string) *Cloud { awsServices := newMockedFakeAWSServices(TestClusterId).WithAz(availabilityZone) - awsCloud, err := newAWSCloud(nil, awsServices) + awsCloud, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { panic(err) } @@ -389,7 +396,7 @@ func TestGetRegion(t *testing.T) { func TestFindVPCID(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -463,7 +470,7 @@ func constructRouteTable(subnetID string, public bool) *ec2.RouteTable { func TestSubnetIDsinVPC(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -798,7 +805,7 @@ func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) { instances := []*ec2.Instance{&terminatedInstance, &runningInstance} awsServices.instances = append(awsServices.instances, instances...) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("Error building aws cloud: %v", err) return @@ -818,7 +825,7 @@ func TestFindInstanceByNodeNameExcludesTerminatedInstances(t *testing.T) { func TestGetInstanceByNodeNameBatching(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) var tag ec2.Tag tag.Key = aws.String(TagNameKubernetesClusterPrefix + TestClusterId) @@ -845,7 +852,7 @@ func TestGetInstanceByNodeNameBatching(t *testing.T) { func TestGetVolumeLabels(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) volumeId := awsVolumeID("vol-VolumeId") expectedVolumeRequest := &ec2.DescribeVolumesInput{VolumeIds: []*string{volumeId.awsString()}} @@ -867,7 +874,7 @@ func TestGetVolumeLabels(t *testing.T) { func TestDescribeLoadBalancerOnDelete(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.EnsureLoadBalancerDeleted(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}) @@ -875,7 +882,7 @@ func TestDescribeLoadBalancerOnDelete(t *testing.T) { func TestDescribeLoadBalancerOnUpdate(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.UpdateLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{}) @@ -883,7 +890,7 @@ func TestDescribeLoadBalancerOnUpdate(t *testing.T) { func TestDescribeLoadBalancerOnGet(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.GetLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}) @@ -891,7 +898,7 @@ func TestDescribeLoadBalancerOnGet(t *testing.T) { func TestDescribeLoadBalancerOnEnsure(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) awsServices.elb.(*MockedFakeELB).expectDescribeLoadBalancers("aid") c.EnsureLoadBalancer(context.TODO(), TestClusterName, &v1.Service{ObjectMeta: metav1.ObjectMeta{Name: "myservice", UID: "id"}}, []*v1.Node{}) @@ -1123,7 +1130,7 @@ func TestGetLoadBalancerAdditionalTags(t *testing.T) { func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) sg1 := "sg-000001" sg2 := "sg-000002" @@ -1159,7 +1166,7 @@ func TestLBExtraSecurityGroupsAnnotation(t *testing.T) { func TestAddLoadBalancerTags(t *testing.T) { loadBalancerName := "test-elb" awsServices := newMockedFakeAWSServices(TestClusterId) - c, _ := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, _ := newAWSCloud(CloudConfig{}, awsServices) want := make(map[string]string) want["tag1"] = "val1" @@ -1215,7 +1222,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC if test.overriddenFieldName != "" { // cater for test case with no overrides @@ -1233,7 +1240,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { t.Run("does not make an API call if the current health check is the same", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC timeout := int64(3) @@ -1255,7 +1262,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { t.Run("validates resulting expected health check before making an API call", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) expectedHC := *defaultHC invalidThreshold := int64(1) @@ -1271,7 +1278,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { t.Run("handles invalid override values", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) annotations := map[string]string{ServiceAnnotationLoadBalancerHCTimeout: "3.3"} @@ -1283,7 +1290,7 @@ func TestEnsureLoadBalancerHealthCheck(t *testing.T) { t.Run("returns error when updating the health check fails", func(t *testing.T) { awsServices := newMockedFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) assert.Nil(t, err, "Error building aws cloud: %v", err) returnErr := fmt.Errorf("throttling error") awsServices.elb.(*MockedFakeELB).expectConfigureHealthCheck(&lbName, defaultHC, returnErr) diff --git a/pkg/cloudprovider/providers/aws/regions_test.go b/pkg/cloudprovider/providers/aws/regions_test.go index 50352f754f9ae..03fb8ff16abeb 100644 --- a/pkg/cloudprovider/providers/aws/regions_test.go +++ b/pkg/cloudprovider/providers/aws/regions_test.go @@ -74,7 +74,7 @@ func TestRecognizesNewRegion(t *testing.T) { } awsServices := NewFakeAWSServices(TestClusterId).WithAz(region + "a") - _, err := newAWSCloud(nil, awsServices) + _, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("error building AWS cloud: %v", err) } diff --git a/pkg/cloudprovider/providers/aws/tags_test.go b/pkg/cloudprovider/providers/aws/tags_test.go index 42185a4f941db..c745451431b40 100644 --- a/pkg/cloudprovider/providers/aws/tags_test.go +++ b/pkg/cloudprovider/providers/aws/tags_test.go @@ -19,13 +19,12 @@ package aws import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" - "strings" "testing" ) func TestFilterTags(t *testing.T) { awsServices := NewFakeAWSServices(TestClusterId) - c, err := newAWSCloud(strings.NewReader("[global]"), awsServices) + c, err := newAWSCloud(CloudConfig{}, awsServices) if err != nil { t.Errorf("Error building aws cloud: %v", err) return