Skip to content

Commit

Permalink
fix: downgrade Azure IMDS required version
Browse files Browse the repository at this point in the history
Fixes #8555

It seems that older version supports same set of fields we actually use
in our platform code, so we can safely downgrade to the version
supported by Azure Stack Hub.

I used
[this repo](https://github.com/Azure/azure-rest-api-specs/tree/main/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable)
to check schemas across versions.

Signed-off-by: Andrey Smirnov <andrey.smirnov@siderolabs.com>
  • Loading branch information
smira committed Jun 3, 2024
1 parent 3086021 commit 9a23d84
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -293,16 +293,16 @@ func (a *Azure) configFromCD() ([]byte, error) {
//
//nolint:gocyclo
func (a *Azure) NetworkConfiguration(ctx context.Context, _ state.State, ch chan<- *runtime.PlatformNetworkConfig) error {
log.Printf("fetching azure instance config from: %q", AzureMetadataEndpoint)

metadata, err := a.getMetadata(ctx)
metadata, apiVersion, err := a.getMetadata(ctx)
if err != nil {
return err
}

log.Printf("fetching network config from %q", AzureInterfacesEndpoint)
interfacesEndpoint := fmt.Sprintf(AzureInterfacesEndpoint, apiVersion)

log.Printf("fetching network config from %q", interfacesEndpoint)

metadataNetworkConfig, err := download.Download(ctx, AzureInterfacesEndpoint,
metadataNetworkConfig, err := download.Download(ctx, interfacesEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}))
if err != nil {
return fmt.Errorf("failed to fetch network config from metadata service: %w", err)
Expand All @@ -319,11 +319,13 @@ func (a *Azure) NetworkConfiguration(ctx context.Context, _ state.State, ch chan
return fmt.Errorf("failed to parse network metadata: %w", err)
}

log.Printf("fetching load balancer metadata from: %q", AzureLoadbalancerEndpoint)
loadbalancerEndpoint := fmt.Sprintf(AzureLoadbalancerEndpoint, apiVersion)

log.Printf("fetching load balancer metadata from: %q", loadbalancerEndpoint)

var loadBalancerAddresses LoadBalancerMetadata

lbConfig, err := download.Download(ctx, AzureLoadbalancerEndpoint,
lbConfig, err := download.Download(ctx, loadbalancerEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}),
download.WithErrorOnNotFound(errors.ErrNoConfigSource),
download.WithErrorOnEmptyResponse(errors.ErrNoConfigSource))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"encoding/json"
stderrors "errors"
"fmt"
"log"

"github.com/siderolabs/talos/internal/app/machined/pkg/runtime/v1alpha1/platform/errors"
"github.com/siderolabs/talos/pkg/download"
)

Expand All @@ -19,15 +19,21 @@ const (
// ref: https://learn.microsoft.com/en-us/azure/virtual-machines/instance-metadata-service
// ref: https://github.com/Azure/azure-rest-api-specs/blob/main/specification/imds/data-plane/Microsoft.InstanceMetadataService/stable/2023-07-01/examples/GetInstanceMetadata.json

// AzureVersion is the version of the Azure metadata service.
AzureVersion = "2021-12-13"

// AzureVersionFallback is the fallback version of the Azure metadata service (e.g. Azure Stack Hub).
AzureVersionFallback = "2019-06-01"

// AzureInternalEndpoint is the Azure Internal Channel IP
// https://blogs.msdn.microsoft.com/mast/2015/05/18/what-is-the-ip-address-168-63-129-16/
AzureInternalEndpoint = "http://168.63.129.16"
// AzureMetadataEndpoint is the local endpoint for the metadata.
AzureMetadataEndpoint = "http://169.254.169.254/metadata/instance/compute?api-version=2021-12-13&format=json"
AzureMetadataEndpoint = "http://169.254.169.254/metadata/instance/compute?api-version=%s&format=json"
// AzureInterfacesEndpoint is the local endpoint to get external IPs.
AzureInterfacesEndpoint = "http://169.254.169.254/metadata/instance/network/interface?api-version=2021-12-13&format=json"
AzureInterfacesEndpoint = "http://169.254.169.254/metadata/instance/network/interface?api-version=%s&format=json"
// AzureLoadbalancerEndpoint is the local endpoint for load balancer config.
AzureLoadbalancerEndpoint = "http://169.254.169.254/metadata/loadbalancer?api-version=2021-05-01&format=json"
AzureLoadbalancerEndpoint = "http://169.254.169.254/metadata/loadbalancer?api-version=%s&format=json"

mnt = "/mnt"
)
Expand All @@ -54,18 +60,38 @@ type ComputeMetadata struct {
EvictionPolicy string `json:"evictionPolicy,omitempty"`
}

func (a *Azure) getMetadata(ctx context.Context) (*ComputeMetadata, error) {
metadataDl, err := download.Download(ctx, AzureMetadataEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}))
if err != nil && !stderrors.Is(err, errors.ErrNoHostname) {
return nil, fmt.Errorf("error fetching metadata: %w", err)
func (a *Azure) getMetadata(ctx context.Context) (*ComputeMetadata, string, error) {
apiVersion := AzureVersion
errBadRequest := stderrors.New("bad request")

metadataEndpoint := fmt.Sprintf(AzureMetadataEndpoint, apiVersion)

log.Printf("fetching azure instance config from: %q", metadataEndpoint)

metadataDl, err := download.Download(ctx, metadataEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}),
download.WithErrorOnBadRequest(errBadRequest),
)
if err != nil && stderrors.Is(err, errBadRequest) {
apiVersion = AzureVersionFallback
metadataEndpoint = fmt.Sprintf(AzureMetadataEndpoint, apiVersion)

log.Printf("fetching azure instance config from: %q", metadataEndpoint)

metadataDl, err = download.Download(ctx, metadataEndpoint,
download.WithHeaders(map[string]string{"Metadata": "true"}),
)
}

if err != nil {
return nil, "", fmt.Errorf("error fetching metadata: %w", err)
}

var metadata ComputeMetadata

if err = json.Unmarshal(metadataDl, &metadata); err != nil {
return nil, fmt.Errorf("failed to parse compute metadata: %w", err)
return nil, "", fmt.Errorf("failed to parse compute metadata: %w", err)
}

return &metadata, nil
return &metadata, apiVersion, nil
}
13 changes: 13 additions & 0 deletions pkg/download/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type downloadOptions struct {
EndpointFunc func(context.Context) (string, error)

ErrorOnNotFound error
ErrorOnBadRequest error
ErrorOnEmptyResponse error

Timeout time.Duration
Expand Down Expand Up @@ -108,6 +109,13 @@ func WithErrorOnEmptyResponse(e error) Option {
}
}

// WithErrorOnBadRequest provides specific error to return when response has HTTP 400 error.
func WithErrorOnBadRequest(e error) Option {
return func(d *downloadOptions) {
d.ErrorOnBadRequest = e
}
}

// WithEndpointFunc provides a function that sets the endpoint of the download options.
func WithEndpointFunc(endpointFunc func(context.Context) (string, error)) Option {
return func(d *downloadOptions) {
Expand Down Expand Up @@ -212,6 +220,7 @@ func Download(ctx context.Context, endpoint string, opts ...Option) (b []byte, e
return b, nil
}

//nolint:gocyclo
func download(req *http.Request, options *downloadOptions) (data []byte, err error) {
transport := httpdefaults.PatchTransport(cleanhttp.DefaultTransport())
transport.RegisterProtocol("tftp", NewTFTPTransport())
Expand Down Expand Up @@ -249,6 +258,10 @@ func download(req *http.Request, options *downloadOptions) (data []byte, err err
return data, options.ErrorOnNotFound
}

if resp.StatusCode == http.StatusBadRequest && options.ErrorOnBadRequest != nil {
return data, options.ErrorOnBadRequest
}

if resp.StatusCode != http.StatusOK {
// try to read first 32 bytes of the response body
// to provide more context in case of error
Expand Down
15 changes: 15 additions & 0 deletions pkg/download/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ func TestDownload(t *testing.T) {
case "/base64":
w.WriteHeader(http.StatusOK)
w.Write([]byte("ZGF0YQ==")) //nolint:errcheck
case "/400":
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintln(w, "bad request")
case "/404":
w.WriteHeader(http.StatusNotFound)
fmt.Fprintln(w, "not found")
Expand Down Expand Up @@ -107,12 +110,24 @@ func TestDownload(t *testing.T) {
opts: []download.Option{download.WithErrorOnNotFound(errors.New("gone forever"))},
expectedError: "gone forever",
},
{
name: "bad request error",
path: "/400",
opts: []download.Option{download.WithErrorOnBadRequest(errors.New("bad req"))},
expectedError: "bad req",
},
{
name: "failure 404",
path: "/404",
opts: []download.Option{download.WithTimeout(2 * time.Second)},
expectedError: "failed to download config, status code 404, body \"not found\\n\"",
},
{
name: "failure 400",
path: "/400",
opts: []download.Option{download.WithTimeout(2 * time.Second)},
expectedError: "failed to download config, status code 400, body \"bad request\\n\"",
},
{
name: "retry endpoint change",
opts: []download.Option{
Expand Down

0 comments on commit 9a23d84

Please sign in to comment.