diff --git a/oci/auth/azure/auth.go b/oci/auth/azure/auth.go index 7a11ce97c..3f7975286 100644 --- a/oci/auth/azure/auth.go +++ b/oci/auth/azure/auth.go @@ -74,10 +74,10 @@ func (c *Client) getLoginAuth(ctx context.Context, registryURL string) (authn.Au c.credential = cred } + configurationEnvironment := getCloudConfiguration(registryURL) // Obtain access token using the token credential. - // TODO: Add support for other azure endpoints as well. armToken, err := c.credential.GetToken(ctx, policy.TokenRequestOptions{ - Scopes: []string{cloud.AzurePublic.Services[cloud.ResourceManager].Endpoint + "/" + ".default"}, + Scopes: []string{configurationEnvironment.Services[cloud.ResourceManager].Endpoint + "/" + ".default"}, }) if err != nil { return authConfig, err @@ -98,6 +98,19 @@ func (c *Client) getLoginAuth(ctx context.Context, registryURL string) (authn.Au }, nil } +// getCloudConfiguration returns the cloud configuration based on the registry URL. +// List from https://github.com/Azure/azure-sdk-for-go/blob/main/sdk/containers/azcontainerregistry/cloud_config.go#L16 +func getCloudConfiguration(url string) cloud.Configuration { + switch { + case strings.HasSuffix(url, ".azurecr.cn"): + return cloud.AzureChina + case strings.HasSuffix(url, ".azurecr.us"): + return cloud.AzureGovernment + default: + return cloud.AzurePublic + } +} + // ValidHost returns if a given host is a Azure container registry. // List from https://github.com/kubernetes/kubernetes/blob/v1.23.1/pkg/credentialprovider/azure/azure_credentials.go#L55 func ValidHost(host string) bool { diff --git a/oci/auth/azure/auth_test.go b/oci/auth/azure/auth_test.go index 21e1f9186..a1618dae7 100644 --- a/oci/auth/azure/auth_test.go +++ b/oci/auth/azure/auth_test.go @@ -26,6 +26,7 @@ import ( "testing" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" "github.com/google/go-containerregistry/pkg/authn" "github.com/google/go-containerregistry/pkg/name" . "github.com/onsi/gomega" @@ -113,6 +114,25 @@ func TestValidHost(t *testing.T) { } } +func TestGetCloudConfiguration(t *testing.T) { + tests := []struct { + host string + result cloud.Configuration + }{ + {"foo.azurecr.io", cloud.AzurePublic}, + {"foo.azurecr.cn", cloud.AzureChina}, + {"foo.azurecr.de", cloud.AzurePublic}, + {"foo.azurecr.us", cloud.AzureGovernment}, + } + + for _, tt := range tests { + t.Run(tt.host, func(t *testing.T) { + g := NewWithT(t) + g.Expect(getCloudConfiguration(tt.host)).To(Equal(tt.result)) + }) + } +} + func TestLogin(t *testing.T) { tests := []struct { name string