Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Workaround: auto-retry when looking for existing service principals, assuming that one is likely to exist when use_existing = true #1025

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions internal/services/serviceprincipals/service_principal_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/manicminer/hamilton/msgraph"
"github.com/manicminer/hamilton/odata"

"github.com/hashicorp/terraform-provider-azuread/internal/clients"
"github.com/hashicorp/terraform-provider-azuread/internal/helpers"
"github.com/hashicorp/terraform-provider-azuread/internal/tf"
"github.com/hashicorp/terraform-provider-azuread/internal/utils"
"github.com/hashicorp/terraform-provider-azuread/internal/validate"
"github.com/manicminer/hamilton/msgraph"
"github.com/manicminer/hamilton/odata"
)

const servicePrincipalResourceName = "azuread_service_principal"
Expand Down Expand Up @@ -356,19 +355,20 @@ func servicePrincipalResourceCreate(ctx context.Context, d *schema.ResourceData,
callerId := meta.(*clients.Client).ObjectID

appId := d.Get("application_id").(string)
result, _, err := client.List(ctx, odata.Query{Filter: fmt.Sprintf("appId eq '%s'", appId)})

var servicePrincipal *msgraph.ServicePrincipal
var err error

if d.Get("use_existing").(bool) {
// Assume that a service principal already exists and try to look for it, whilst retrying to defeat eventual consistency
servicePrincipal, err = findByAppIdWithTimeout(ctx, 5*time.Minute, client, appId)
} else {
// Otherwise perform a single List operation to check for an existing service principal
servicePrincipal, err = findByAppId(ctx, client, appId)
}
if err != nil {
return tf.ErrorDiagF(err, "Could not list existing service principals")
}
var servicePrincipal *msgraph.ServicePrincipal
if result != nil {
for _, r := range *result {
if r.AppId != nil && strings.EqualFold(*r.AppId, appId) {
servicePrincipal = &r
break
}
}
}

if servicePrincipal != nil {
if servicePrincipal.ID() == nil || *servicePrincipal.ID() == "" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,13 +644,17 @@ func (ServicePrincipalResource) fromApplicationTemplate(data acceptance.TestData
return fmt.Sprintf(`
provider "azuread" {}

data "azuread_client_config" "test" {}

resource "azuread_application" "test" {
display_name = "acctest-APP-%[1]d"
template_id = "%[2]s"
owners = [data.azuread_client_config.test.object_id]
}

resource "azuread_service_principal" "test" {
application_id = azuread_application.test.application_id
owners = [data.azuread_client_config.test.object_id]
use_existing = true
}
`, data.RandomInteger, testApplicationTemplateId)
Expand Down
90 changes: 89 additions & 1 deletion internal/services/serviceprincipals/serviceprincipals.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package serviceprincipals

import (
"github.com/manicminer/hamilton/msgraph"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/hashicorp/terraform-provider-azuread/internal/utils"
"github.com/manicminer/hamilton/msgraph"
"github.com/manicminer/hamilton/odata"
)

func expandSamlSingleSignOn(in []interface{}) *msgraph.SamlSingleSignOnSettings {
Expand Down Expand Up @@ -33,3 +42,82 @@ func flattenSamlSingleSignOn(in *msgraph.SamlSingleSignOnSettings) []map[string]
"relay_state": relayState,
}}
}

func findByAppId(ctx context.Context, client *msgraph.ServicePrincipalsClient, appId string) (*msgraph.ServicePrincipal, error) {
var servicePrincipal *msgraph.ServicePrincipal

result, _, err := client.List(ctx, odata.Query{Filter: fmt.Sprintf("appId eq '%s'", appId)})
if err != nil {
return nil, fmt.Errorf("could not list existing service principals")
}
if result != nil {
for _, r := range *result {
if r.AppId != nil && strings.EqualFold(*r.AppId, appId) {
servicePrincipal = &r
break
}
}
}

return servicePrincipal, nil
}

func findByAppIdWithTimeout(ctx context.Context, timeout time.Duration, client *msgraph.ServicePrincipalsClient, appId string) (*msgraph.ServicePrincipal, error) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()

unmarshal := func(resp *http.Response) (*msgraph.ServicePrincipal, error) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("listing service principals: %v", err)
}

// Close and reassign the response body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(respBody))

var data struct {
ServicePrincipals []msgraph.ServicePrincipal `json:"value"`
}
if err := json.Unmarshal(respBody, &data); err != nil {
return nil, fmt.Errorf("unmarshaling service principals: %v", err)
}

if len(data.ServicePrincipals) == 0 {
return nil, nil
} else if len(data.ServicePrincipals) > 1 {
return nil, fmt.Errorf("unexpected number of results, should have received 0 or 1, got %d", len(data.ServicePrincipals))
}

if data.ServicePrincipals[0].AppId == nil || !strings.EqualFold(*data.ServicePrincipals[0].AppId, appId) {
return nil, fmt.Errorf("returned service principal did not have a matching appId, expected %q, received %q", appId, *data.ServicePrincipals[0].AppId)
}

return &data.ServicePrincipals[0], nil
}

notReplicated := func(resp *http.Response, o *odata.OData) bool {
sp, err := unmarshal(resp)
if err == nil && sp == nil {
return false
}
return false
}

resp, _, _, err := client.BaseClient.Get(ctx, msgraph.GetHttpRequestInput{
ConsistencyFailureFunc: notReplicated,
DisablePaging: true,
OData: odata.Query{Filter: fmt.Sprintf("appId eq '%s'", appId)},
ValidStatusCodes: []int{http.StatusOK},
Uri: msgraph.Uri{
Entity: "/servicePrincipals",
HasTenantId: true,
},
})
if err != nil {
return nil, fmt.Errorf("listing service principals: %v", err)
}

return unmarshal(resp)
}