diff --git a/core/appeal/service.go b/core/appeal/service.go index 2c9741e4..07d347b5 100644 --- a/core/appeal/service.go +++ b/core/appeal/service.go @@ -1489,9 +1489,14 @@ func (s *Service) GrantAccessToProvider(ctx context.Context, a *domain.Appeal, o } } + g := a.Grant + appealCopy := *a + appealCopy.Grant = nil + g.Appeal = &appealCopy if err := s.providerService.GrantAccess(ctx, *a.Grant); err != nil { return fmt.Errorf("granting access: %w", err) } + g.Appeal = nil return nil } diff --git a/plugins/providers/maxcompute/config.go b/plugins/providers/maxcompute/config.go index eb5a567a..c1d5f3f4 100644 --- a/plugins/providers/maxcompute/config.go +++ b/plugins/providers/maxcompute/config.go @@ -2,6 +2,7 @@ package maxcompute import ( "fmt" + "slices" "github.com/goto/guardian/domain" "github.com/goto/guardian/utils" @@ -14,6 +15,8 @@ const ( resourceTypeProject = "project" resourceTypeTable = "table" + + parameterRAMRoleKey = "ram_role" ) var ( @@ -71,5 +74,12 @@ func (c *config) validate() error { } } + // validate parameters + for _, param := range c.Parameters { + if !slices.Contains([]string{parameterRAMRoleKey}, param.Key) { + return fmt.Errorf("parameter key %q is not supported", param.Key) + } + } + return nil } diff --git a/plugins/providers/maxcompute/provider.go b/plugins/providers/maxcompute/provider.go index 193963a7..440bf6f8 100644 --- a/plugins/providers/maxcompute/provider.go +++ b/plugins/providers/maxcompute/provider.go @@ -5,7 +5,6 @@ import ( "slices" "strings" "sync" - "time" openapi "github.com/alibabacloud-go/darabonba-openapi/client" openapiv2 "github.com/alibabacloud-go/darabonba-openapi/v2/client" @@ -149,7 +148,8 @@ func (p *provider) GetResources(ctx context.Context, pc *domain.ProviderConfig) } func (p *provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, g domain.Grant) error { - client, err := p.getOdpsClient(pc) + ramRole, _ := getParametersFromGrant[string](g, parameterRAMRoleKey) + client, err := p.getOdpsClient(pc, ramRole) if err != nil { return err } @@ -215,7 +215,8 @@ func (p *provider) GrantAccess(ctx context.Context, pc *domain.ProviderConfig, g } func (p *provider) RevokeAccess(ctx context.Context, pc *domain.ProviderConfig, g domain.Grant) error { - client, err := p.getOdpsClient(pc) + ramRole, _ := getParametersFromGrant[string](g, parameterRAMRoleKey) + client, err := p.getOdpsClient(pc, ramRole) if err != nil { return err } @@ -291,14 +292,14 @@ func (p *provider) getCreds(pc *domain.ProviderConfig) (*credentials, error) { return creds, nil } -func (p *provider) getClientConfig(providerURN string, creds *credentials) (*openapiv2.Config, error) { +func getClientConfig(providerURN, accountID, accountSecret, regionID, assumeAsRAMRole string) (*openapiv2.Config, error) { configV2 := &openapiv2.Config{ - AccessKeyId: &creds.AccessKeyID, - AccessKeySecret: &creds.AccessKeySecret, - Endpoint: &[]string{fmt.Sprintf("maxcompute.%s.aliyuncs.com", creds.RegionID)}[0], + AccessKeyId: &accountID, + AccessKeySecret: &accountSecret, + Endpoint: &[]string{fmt.Sprintf("maxcompute.%s.aliyuncs.com", regionID)}[0], } - if creds.RAMRole != "" { - stsEndpoint := fmt.Sprintf("sts.%s.aliyuncs.com", creds.RegionID) + if assumeAsRAMRole != "" { + stsEndpoint := fmt.Sprintf("sts.%s.aliyuncs.com", regionID) configV1 := &openapi.Config{ AccessKeyId: configV2.AccessKeyId, AccessKeySecret: configV2.AccessKeySecret, @@ -308,13 +309,12 @@ func (p *provider) getClientConfig(providerURN string, creds *credentials) (*ope if err != nil { return nil, fmt.Errorf("failed to initialize STS client: %w", err) } - sessionName := fmt.Sprintf("%s-%s", providerURN, time.Now().Format("2001-01-02T15:04:05")) res, err := stsClient.AssumeRole(&sts.AssumeRoleRequest{ - RoleArn: &creds.RAMRole, - RoleSessionName: &sessionName, + RoleArn: &assumeAsRAMRole, + RoleSessionName: &providerURN, }) if err != nil { - return nil, fmt.Errorf("failed to assume role %q: %w", creds.RAMRole, err) + return nil, fmt.Errorf("failed to assume role %q: %w", assumeAsRAMRole, err) } // TODO: handle refreshing token when the used one is expired @@ -335,7 +335,7 @@ func (p *provider) getRestClient(pc *domain.ProviderConfig) (*maxcompute.Client, if err != nil { return nil, err } - clientConfig, err := p.getClientConfig(pc.URN, creds) + clientConfig, err := getClientConfig(pc.URN, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID, creds.RAMRole) if err != nil { return nil, err } @@ -350,8 +350,13 @@ func (p *provider) getRestClient(pc *domain.ProviderConfig) (*maxcompute.Client, return restClient, nil } -func (p *provider) getOdpsClient(pc *domain.ProviderConfig) (*odps.Odps, error) { - if client, ok := p.odpsClients[pc.URN]; ok { +func (p *provider) getOdpsClient(pc *domain.ProviderConfig, overrideRAMRole string) (*odps.Odps, error) { + usingRAMRole := overrideRAMRole != "" + if usingRAMRole { + if client, ok := p.odpsClients[overrideRAMRole]; ok { + return client, nil + } + } else if client, ok := p.odpsClients[pc.URN]; ok { return client, nil } @@ -359,7 +364,12 @@ func (p *provider) getOdpsClient(pc *domain.ProviderConfig) (*odps.Odps, error) if err != nil { return nil, err } - clientConfig, err := p.getClientConfig(pc.URN, creds) + ramRole := creds.RAMRole + if usingRAMRole { + ramRole = overrideRAMRole + } + + clientConfig, err := getClientConfig(pc.URN, creds.AccessKeyID, creds.AccessKeySecret, creds.RegionID, ramRole) if err != nil { return nil, err } @@ -373,7 +383,17 @@ func (p *provider) getOdpsClient(pc *domain.ProviderConfig) (*odps.Odps, error) client := odps.NewOdps(acc, endpoint) p.mu.Lock() - p.odpsClients[pc.URN] = client + if usingRAMRole { + p.odpsClients[overrideRAMRole] = client + } else { + p.odpsClients[pc.URN] = client + } p.mu.Unlock() return client, nil } + +func getParametersFromGrant[T any](g domain.Grant, key string) (T, bool) { + appealParams, _ := g.Appeal.Details[domain.ReservedDetailsKeyProviderParameters].(map[string]any) + value, ok := appealParams[key].(T) + return value, ok +}