diff --git a/internal/services/serviceprincipals/service_principal_resource.go b/internal/services/serviceprincipals/service_principal_resource.go index 458718203..8e1fa8bd1 100644 --- a/internal/services/serviceprincipals/service_principal_resource.go +++ b/internal/services/serviceprincipals/service_principal_resource.go @@ -13,14 +13,13 @@ import ( "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" - "github.com/manicminer/hamilton/msgraph" - "github.com/manicminer/hamilton/odata" - "github.com/hashicorp/terraform-provider-azuread/internal/clients" "github.com/hashicorp/terraform-provider-azuread/internal/helpers" "github.com/hashicorp/terraform-provider-azuread/internal/tf" "github.com/hashicorp/terraform-provider-azuread/internal/utils" "github.com/hashicorp/terraform-provider-azuread/internal/validate" + "github.com/manicminer/hamilton/msgraph" + "github.com/manicminer/hamilton/odata" ) const servicePrincipalResourceName = "azuread_service_principal" @@ -356,19 +355,20 @@ func servicePrincipalResourceCreate(ctx context.Context, d *schema.ResourceData, callerId := meta.(*clients.Client).ObjectID appId := d.Get("application_id").(string) - result, _, err := client.List(ctx, odata.Query{Filter: fmt.Sprintf("appId eq '%s'", appId)}) + + var servicePrincipal *msgraph.ServicePrincipal + var err error + + if d.Get("use_existing").(bool) { + // Assume that a service principal already exists and try to look for it, whilst retrying to defeat eventual consistency + servicePrincipal, err = findByAppIdWithTimeout(ctx, 5*time.Minute, client, appId) + } else { + // Otherwise perform a single List operation to check for an existing service principal + servicePrincipal, err = findByAppId(ctx, client, appId) + } if err != nil { return tf.ErrorDiagF(err, "Could not list existing service principals") } - var servicePrincipal *msgraph.ServicePrincipal - if result != nil { - for _, r := range *result { - if r.AppId != nil && strings.EqualFold(*r.AppId, appId) { - servicePrincipal = &r - break - } - } - } if servicePrincipal != nil { if servicePrincipal.ID() == nil || *servicePrincipal.ID() == "" { diff --git a/internal/services/serviceprincipals/service_principal_resource_test.go b/internal/services/serviceprincipals/service_principal_resource_test.go index ad6b46409..679105d4c 100644 --- a/internal/services/serviceprincipals/service_principal_resource_test.go +++ b/internal/services/serviceprincipals/service_principal_resource_test.go @@ -644,13 +644,17 @@ func (ServicePrincipalResource) fromApplicationTemplate(data acceptance.TestData return fmt.Sprintf(` provider "azuread" {} +data "azuread_client_config" "test" {} + resource "azuread_application" "test" { display_name = "acctest-APP-%[1]d" template_id = "%[2]s" + owners = [data.azuread_client_config.test.object_id] } resource "azuread_service_principal" "test" { application_id = azuread_application.test.application_id + owners = [data.azuread_client_config.test.object_id] use_existing = true } `, data.RandomInteger, testApplicationTemplateId) diff --git a/internal/services/serviceprincipals/serviceprincipals.go b/internal/services/serviceprincipals/serviceprincipals.go index 7d6fbf24a..780fe31ee 100644 --- a/internal/services/serviceprincipals/serviceprincipals.go +++ b/internal/services/serviceprincipals/serviceprincipals.go @@ -1,9 +1,18 @@ package serviceprincipals import ( - "github.com/manicminer/hamilton/msgraph" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" "github.com/hashicorp/terraform-provider-azuread/internal/utils" + "github.com/manicminer/hamilton/msgraph" + "github.com/manicminer/hamilton/odata" ) func expandSamlSingleSignOn(in []interface{}) *msgraph.SamlSingleSignOnSettings { @@ -33,3 +42,82 @@ func flattenSamlSingleSignOn(in *msgraph.SamlSingleSignOnSettings) []map[string] "relay_state": relayState, }} } + +func findByAppId(ctx context.Context, client *msgraph.ServicePrincipalsClient, appId string) (*msgraph.ServicePrincipal, error) { + var servicePrincipal *msgraph.ServicePrincipal + + result, _, err := client.List(ctx, odata.Query{Filter: fmt.Sprintf("appId eq '%s'", appId)}) + if err != nil { + return nil, fmt.Errorf("could not list existing service principals") + } + if result != nil { + for _, r := range *result { + if r.AppId != nil && strings.EqualFold(*r.AppId, appId) { + servicePrincipal = &r + break + } + } + } + + return servicePrincipal, nil +} + +func findByAppIdWithTimeout(ctx context.Context, timeout time.Duration, client *msgraph.ServicePrincipalsClient, appId string) (*msgraph.ServicePrincipal, error) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + + unmarshal := func(resp *http.Response) (*msgraph.ServicePrincipal, error) { + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("listing service principals: %v", err) + } + + // Close and reassign the response body + resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewBuffer(respBody)) + + var data struct { + ServicePrincipals []msgraph.ServicePrincipal `json:"value"` + } + if err := json.Unmarshal(respBody, &data); err != nil { + return nil, fmt.Errorf("unmarshaling service principals: %v", err) + } + + if len(data.ServicePrincipals) == 0 { + return nil, nil + } else if len(data.ServicePrincipals) > 1 { + return nil, fmt.Errorf("unexpected number of results, should have received 0 or 1, got %d", len(data.ServicePrincipals)) + } + + if data.ServicePrincipals[0].AppId == nil || !strings.EqualFold(*data.ServicePrincipals[0].AppId, appId) { + return nil, fmt.Errorf("returned service principal did not have a matching appId, expected %q, received %q", appId, *data.ServicePrincipals[0].AppId) + } + + return &data.ServicePrincipals[0], nil + } + + notReplicated := func(resp *http.Response, o *odata.OData) bool { + sp, err := unmarshal(resp) + if err == nil && sp == nil { + return false + } + return false + } + + resp, _, _, err := client.BaseClient.Get(ctx, msgraph.GetHttpRequestInput{ + ConsistencyFailureFunc: notReplicated, + DisablePaging: true, + OData: odata.Query{Filter: fmt.Sprintf("appId eq '%s'", appId)}, + ValidStatusCodes: []int{http.StatusOK}, + Uri: msgraph.Uri{ + Entity: "/servicePrincipals", + HasTenantId: true, + }, + }) + if err != nil { + return nil, fmt.Errorf("listing service principals: %v", err) + } + + return unmarshal(resp) +}