diff --git a/v2/header.go b/v2/header.go index 6488461f..453fab7e 100644 --- a/v2/header.go +++ b/v2/header.go @@ -31,9 +31,15 @@ package gax import ( "bytes" + "context" + "fmt" + "net/http" "runtime" "strings" "unicode" + + "github.com/googleapis/gax-go/v2/callctx" + "google.golang.org/grpc/metadata" ) var ( @@ -117,3 +123,46 @@ func XGoogHeader(keyval ...string) string { } return buf.String()[1:] } + +// InsertMetadataIntoOutgoingContext is for use by the Google Cloud Libraries +// only. +// +// InsertMetadataIntoOutgoingContext returns a new context that merges the +// provided keyvals metadata pairs with any existing metadata/headers in the +// provided context. keyvals should have a corresponding value for every key +// provided. If there is an odd number of keyvals this method will panic. +// Existing values for keys will not be overwritten, instead provided values +// will be appended to the list of existing values. +func InsertMetadataIntoOutgoingContext(ctx context.Context, keyvals ...string) context.Context { + return metadata.NewOutgoingContext(ctx, insertMetadata(ctx, keyvals...)) +} + +// BuildHeaders is for use by the Google Cloud Libraries only. +// +// BuildHeaders returns a new http.Header that merges the provided +// keyvals header pairs with any existing metadata/headers in the provided +// context. keyvals should have a corresponding value for every key provided. +// If there is an odd number of keyvals this method will panic. +// Existing values for keys will not be overwritten, instead provided values +// will be appended to the list of existing values. +func BuildHeaders(ctx context.Context, keyvals ...string) http.Header { + return http.Header(insertMetadata(ctx, keyvals...)) +} + +func insertMetadata(ctx context.Context, keyvals ...string) metadata.MD { + if len(keyvals)%2 != 0 { + panic(fmt.Sprintf("gax: an even number of key value pairs must be provided, got %d", len(keyvals))) + } + out, ok := metadata.FromOutgoingContext(ctx) + if !ok { + out = metadata.MD(make(map[string][]string)) + } + headers := callctx.HeadersFromContext(ctx) + for k, v := range headers { + out[k] = append(out[k], v...) + } + for i := 0; i < len(keyvals); i = i + 2 { + out[keyvals[i]] = append(out[keyvals[i]], keyvals[i+1]) + } + return out +} diff --git a/v2/header_test.go b/v2/header_test.go index 94b5ccf8..74b2e046 100644 --- a/v2/header_test.go +++ b/v2/header_test.go @@ -30,9 +30,13 @@ package gax import ( + "context" + "net/http" "testing" "github.com/google/go-cmp/cmp" + "github.com/googleapis/gax-go/v2/callctx" + "google.golang.org/grpc/metadata" ) func TestXGoogHeader(t *testing.T) { @@ -89,3 +93,64 @@ func TestGoVersion(t *testing.T) { } } } + +func TestInsertMetadataIntoOutgoingContext(t *testing.T) { + for _, tst := range []struct { + // User-provided metadata set in context + userMd metadata.MD + // User-provided headers set in context + userHeaders []string + // Client-provided headers passed to func + clientHeaders []string + want metadata.MD + }{ + { + userMd: metadata.Pairs("key_1", "val_1", "key_2", "val_21"), + want: metadata.Pairs("key_1", "val_1", "key_2", "val_21"), + }, + { + userHeaders: []string{"key_2", "val_22"}, + want: metadata.Pairs("key_2", "val_22"), + }, + { + clientHeaders: []string{"key_2", "val_23", "key_2", "val_24"}, + want: metadata.Pairs("key_2", "val_23", "key_2", "val_24"), + }, + { + userMd: metadata.Pairs("key_1", "val_1", "key_2", "val_21"), + userHeaders: []string{"key_2", "val_22"}, + clientHeaders: []string{"key_2", "val_23", "key_2", "val_24"}, + want: metadata.Pairs("key_1", "val_1", "key_2", "val_21", "key_2", "val_22", "key_2", "val_23", "key_2", "val_24"), + }, + } { + ctx := context.Background() + if tst.userMd != nil { + ctx = metadata.NewOutgoingContext(ctx, tst.userMd) + } + ctx = callctx.SetHeaders(ctx, tst.userHeaders...) + + ctx = InsertMetadataIntoOutgoingContext(ctx, tst.clientHeaders...) + + got, _ := metadata.FromOutgoingContext(ctx) + if diff := cmp.Diff(tst.want, got); diff != "" { + t.Errorf("InsertMetadata(ctx, %q) mismatch (-want +got):\n%s", tst.clientHeaders, diff) + } + } +} + +func TestBuildHeaders(t *testing.T) { + // User-provided metadata set in context + existingMd := metadata.Pairs("key_1", "val_1", "key_2", "val_21") + ctx := metadata.NewOutgoingContext(context.Background(), existingMd) + // User-provided headers set in context + ctx = callctx.SetHeaders(ctx, "key_2", "val_22") + // Client-provided headers + keyvals := []string{"key_2", "val_23", "key_2", "val_24"} + + got := BuildHeaders(ctx, keyvals...) + + want := http.Header{"key_1": []string{"val_1"}, "key_2": []string{"val_21", "val_22", "val_23", "val_24"}} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("InsertMetadata(ctx, %q) mismatch (-want +got):\n%s", keyvals, diff) + } +}