Skip to content

Commit

Permalink
Merge pull request #1025 from hashicorp/bugfix/existing-service-princ…
Browse files Browse the repository at this point in the history
…ipal-consistency-workaround

Workaround: auto-retry when looking for existing service principals, assuming that one is likely to exist when `use_existing = true`
  • Loading branch information
manicminer authored Feb 21, 2023
2 parents 34d062f + 18860da commit adc8bc5
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 14 deletions.
26 changes: 13 additions & 13 deletions internal/services/serviceprincipals/service_principal_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/manicminer/hamilton/msgraph"
"github.com/manicminer/hamilton/odata"

"github.com/hashicorp/terraform-provider-azuread/internal/clients"
"github.com/hashicorp/terraform-provider-azuread/internal/helpers"
"github.com/hashicorp/terraform-provider-azuread/internal/tf"
"github.com/hashicorp/terraform-provider-azuread/internal/utils"
"github.com/hashicorp/terraform-provider-azuread/internal/validate"
"github.com/manicminer/hamilton/msgraph"
"github.com/manicminer/hamilton/odata"
)

const servicePrincipalResourceName = "azuread_service_principal"
Expand Down Expand Up @@ -356,19 +355,20 @@ func servicePrincipalResourceCreate(ctx context.Context, d *schema.ResourceData,
callerId := meta.(*clients.Client).ObjectID

appId := d.Get("application_id").(string)
result, _, err := client.List(ctx, odata.Query{Filter: fmt.Sprintf("appId eq '%s'", appId)})

var servicePrincipal *msgraph.ServicePrincipal
var err error

if d.Get("use_existing").(bool) {
// Assume that a service principal already exists and try to look for it, whilst retrying to defeat eventual consistency
servicePrincipal, err = findByAppIdWithTimeout(ctx, 5*time.Minute, client, appId)
} else {
// Otherwise perform a single List operation to check for an existing service principal
servicePrincipal, err = findByAppId(ctx, client, appId)
}
if err != nil {
return tf.ErrorDiagF(err, "Could not list existing service principals")
}
var servicePrincipal *msgraph.ServicePrincipal
if result != nil {
for _, r := range *result {
if r.AppId != nil && strings.EqualFold(*r.AppId, appId) {
servicePrincipal = &r
break
}
}
}

if servicePrincipal != nil {
if servicePrincipal.ID() == nil || *servicePrincipal.ID() == "" {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -644,13 +644,17 @@ func (ServicePrincipalResource) fromApplicationTemplate(data acceptance.TestData
return fmt.Sprintf(`
provider "azuread" {}
data "azuread_client_config" "test" {}
resource "azuread_application" "test" {
display_name = "acctest-APP-%[1]d"
template_id = "%[2]s"
owners = [data.azuread_client_config.test.object_id]
}
resource "azuread_service_principal" "test" {
application_id = azuread_application.test.application_id
owners = [data.azuread_client_config.test.object_id]
use_existing = true
}
`, data.RandomInteger, testApplicationTemplateId)
Expand Down
90 changes: 89 additions & 1 deletion internal/services/serviceprincipals/serviceprincipals.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
package serviceprincipals

import (
"github.com/manicminer/hamilton/msgraph"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"

"github.com/hashicorp/terraform-provider-azuread/internal/utils"
"github.com/manicminer/hamilton/msgraph"
"github.com/manicminer/hamilton/odata"
)

func expandSamlSingleSignOn(in []interface{}) *msgraph.SamlSingleSignOnSettings {
Expand Down Expand Up @@ -33,3 +42,82 @@ func flattenSamlSingleSignOn(in *msgraph.SamlSingleSignOnSettings) []map[string]
"relay_state": relayState,
}}
}

func findByAppId(ctx context.Context, client *msgraph.ServicePrincipalsClient, appId string) (*msgraph.ServicePrincipal, error) {
var servicePrincipal *msgraph.ServicePrincipal

result, _, err := client.List(ctx, odata.Query{Filter: fmt.Sprintf("appId eq '%s'", appId)})
if err != nil {
return nil, fmt.Errorf("could not list existing service principals")
}
if result != nil {
for _, r := range *result {
if r.AppId != nil && strings.EqualFold(*r.AppId, appId) {
servicePrincipal = &r
break
}
}
}

return servicePrincipal, nil
}

func findByAppIdWithTimeout(ctx context.Context, timeout time.Duration, client *msgraph.ServicePrincipalsClient, appId string) (*msgraph.ServicePrincipal, error) {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()

unmarshal := func(resp *http.Response) (*msgraph.ServicePrincipal, error) {
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("listing service principals: %v", err)
}

// Close and reassign the response body
resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewBuffer(respBody))

var data struct {
ServicePrincipals []msgraph.ServicePrincipal `json:"value"`
}
if err := json.Unmarshal(respBody, &data); err != nil {
return nil, fmt.Errorf("unmarshaling service principals: %v", err)
}

if len(data.ServicePrincipals) == 0 {
return nil, nil
} else if len(data.ServicePrincipals) > 1 {
return nil, fmt.Errorf("unexpected number of results, should have received 0 or 1, got %d", len(data.ServicePrincipals))
}

if data.ServicePrincipals[0].AppId == nil || !strings.EqualFold(*data.ServicePrincipals[0].AppId, appId) {
return nil, fmt.Errorf("returned service principal did not have a matching appId, expected %q, received %q", appId, *data.ServicePrincipals[0].AppId)
}

return &data.ServicePrincipals[0], nil
}

notReplicated := func(resp *http.Response, o *odata.OData) bool {
sp, err := unmarshal(resp)
if err == nil && sp == nil {
return false
}
return false
}

resp, _, _, err := client.BaseClient.Get(ctx, msgraph.GetHttpRequestInput{
ConsistencyFailureFunc: notReplicated,
DisablePaging: true,
OData: odata.Query{Filter: fmt.Sprintf("appId eq '%s'", appId)},
ValidStatusCodes: []int{http.StatusOK},
Uri: msgraph.Uri{
Entity: "/servicePrincipals",
HasTenantId: true,
},
})
if err != nil {
return nil, fmt.Errorf("listing service principals: %v", err)
}

return unmarshal(resp)
}

0 comments on commit adc8bc5

Please sign in to comment.