diff --git a/sdk/internal/recording/recording.go b/sdk/internal/recording/recording.go index b947f37ec040..26d34cd57a07 100644 --- a/sdk/internal/recording/recording.go +++ b/sdk/internal/recording/recording.go @@ -7,10 +7,6 @@ package recording import ( - "bytes" - "crypto/tls" - "crypto/x509" - "encoding/json" "errors" "fmt" "io/ioutil" @@ -20,7 +16,6 @@ import ( "path/filepath" "strconv" "strings" - "testing" "time" "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" @@ -40,6 +35,7 @@ type Recording struct { src rand.Source now *time.Time Sanitizer *Sanitizer + Matcher *RequestMatcher c TestContext } @@ -69,8 +65,11 @@ const ( type VariableType string const ( - Default VariableType = "default" - Secret_String VariableType = "secret_string" + // NoSanitization indicates that the recorded value should not be sanitized. + NoSanitization VariableType = "default" + // Secret_String indicates that the recorded value should be replaced with a sanitized value. + Secret_String VariableType = "secret_string" + // Secret_Base64String indicates that the recorded value should be replaced with a sanitized valid base-64 string value. Secret_Base64String VariableType = "secret_base64String" ) @@ -107,17 +106,18 @@ func NewRecording(c TestContext, mode RecordMode) (*Recording, error) { } // set the recorder Matcher + recording.Matcher = defaultMatcher(c) rec.SetMatcher(recording.matchRequest) // wire up the sanitizer - recording.Sanitizer = DefaultSanitizer(rec) + recording.Sanitizer = defaultSanitizer(rec) return recording, err } -// GetRecordedVariable returns a recorded variable. If the variable is not found we return an error -// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation. -func (r *Recording) GetRecordedVariable(name string, variableType VariableType) (string, error) { +// GetEnvVar returns a recorded environment variable. If the variable is not found we return an error. +// variableType determines how the recorded variable will be saved. +func (r *Recording) GetEnvVar(name string, variableType VariableType) (string, error) { var err error result, ok := r.previousSessionVariables[name] if !ok || r.Mode == Live { @@ -132,9 +132,10 @@ func (r *Recording) GetRecordedVariable(name string, variableType VariableType) return *result, err } -// GetOptionalRecordedVariable returns a recorded variable with a fallback default value -// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation. -func (r *Recording) GetOptionalRecordedVariable(name string, defaultValue string, variableType VariableType) string { +// GetOptionalEnvVar returns a recorded environment variable with a fallback default value. +// default Value configures the fallback value to be returned if the environment variable is not set. +// variableType determines how the recorded variable will be saved. +func (r *Recording) GetOptionalEnvVar(name string, defaultValue string, variableType VariableType) string { result, ok := r.previousSessionVariables[name] if !ok || r.Mode == Live { result = getOptionalEnv(name, defaultValue) @@ -280,10 +281,10 @@ func getOptionalEnv(name string, defaultValue string) *string { } func (r *Recording) matchRequest(req *http.Request, rec cassette.Request) bool { - isMatch := compareMethods(req, rec, r.c) && - compareURLs(req, rec, r.c) && - compareHeaders(req, rec, r.c) && - compareBodies(req, rec, r.c) + isMatch := r.Matcher.compareMethods(req, rec.Method) && + r.Matcher.compareURLs(req, rec.URL) && + r.Matcher.compareHeaders(req, rec) && + r.Matcher.compareBodies(req, rec.Body) return isMatch } @@ -432,272 +433,3 @@ var modeMap = map[RecordMode]recorder.Mode{ Live: recorder.ModeDisabled, Playback: recorder.ModeReplaying, } - -var recordMode, _ = os.LookupEnv("AZURE_RECORD_MODE") -var ModeRecording = "record" -var ModePlayback = "playback" - -var baseProxyURLSecure = "localhost:5001" -var baseProxyURL = "localhost:5000" -var startURL = baseProxyURLSecure + "/record/start" -var stopURL = baseProxyURLSecure + "/record/stop" - -var recordingId string -var IdHeader = "x-recording-id" -var ModeHeader = "x-recording-mode" -var UpstreamUriHeader = "x-recording-upstream-base-uri" - -var tr = &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, -} -var client = http.Client{ - Transport: tr, -} - -type RecordingOptions struct { - MaxRetries int32 - UseHTTPS bool - Host string - Scheme string -} - -func defaultOptions() *RecordingOptions { - return &RecordingOptions{ - MaxRetries: 0, - UseHTTPS: true, - Host: "localhost:5001", - Scheme: "https", - } -} - -func (r RecordingOptions) HostScheme() string { - if r.UseHTTPS { - return "https://localhost:5001" - } - return "http://localhost:5000" -} - -func getTestId(t *testing.T) string { - cwd, err := os.Getwd() - if err != nil { - t.Errorf("Could not find current working directory") - } - cwd = "./recordings/" + t.Name() + ".json" - return cwd -} - -func StartRecording(t *testing.T, options *RecordingOptions) error { - if options == nil { - options = defaultOptions() - } - if recordMode == "" { - t.Log("AZURE_RECORD_MODE was not set, options are \"record\" or \"playback\". \nDefaulting to playback") - recordMode = "playback" - } else { - t.Log("AZURE_RECORD_MODE: ", recordMode) - } - testId := getTestId(t) - - url := fmt.Sprintf("%v/%v/start", options.HostScheme(), recordMode) - - req, err := http.NewRequest("POST", url, nil) - if err != nil { - return err - } - - req.Header.Set("x-recording-file", testId) - - resp, err := client.Do(req) - if err != nil { - return err - } - recordingId = resp.Header.Get(IdHeader) - return nil -} - -func StopRecording(t *testing.T, options *RecordingOptions) error { - if options == nil { - options = defaultOptions() - } - - url := fmt.Sprintf("%v/%v/stop", options.HostScheme(), recordMode) - req, err := http.NewRequest("POST", url, nil) - if err != nil { - return err - } - if recordingId == "" { - return errors.New("Recording ID was never set. Did you call StartRecording?") - } - req.Header.Set("x-recording-id", recordingId) - _, err = client.Do(req) - if err != nil { - t.Errorf(err.Error()) - } - return nil -} - -func AddUriSanitizer(replacement, regex string, options *RecordingOptions) error { - if options == nil { - options = defaultOptions() - } - url := fmt.Sprintf("%v/Admin/AddSanitizer", options.HostScheme()) - req, err := http.NewRequest("POST", url, nil) - if err != nil { - return err - } - req.Header.Set("x-abstraction-identifier", "UriRegexSanitizer") - bodyContent := map[string]string{ - "value": replacement, - "regex": regex, - } - marshalled, err := json.Marshal(bodyContent) - if err != nil { - return err - } - req.Body = ioutil.NopCloser(bytes.NewReader(marshalled)) - req.ContentLength = int64(len(marshalled)) - _, err = client.Do(req) - return err -} - -func (o *RecordingOptions) Init() { - if o.MaxRetries != 0 { - o.MaxRetries = 0 - } - if o.UseHTTPS { - o.Host = baseProxyURLSecure - o.Scheme = "https" - } else { - o.Host = baseProxyURL - o.Scheme = "http" - } -} - -// type recordingPolicy struct { -// options RecordingOptions -// } - -// func NewRecordingPolicy(o *RecordingOptions) azcore.Policy { -// if o == nil { -// o = &RecordingOptions{} -// } -// p := &recordingPolicy{options: *o} -// p.options.init() -// return p -// } - -// func (p *recordingPolicy) Do(req *azcore.Request) (resp *azcore.Response, err error) { -// originalURLHost := req.URL.Host -// req.URL.Scheme = "https" -// req.URL.Host = p.options.host -// req.Host = p.options.host - -// req.Header.Set(UpstreamUriHeader, fmt.Sprintf("%v://%v", p.options.scheme, originalURLHost)) -// req.Header.Set(ModeHeader, recordMode) -// req.Header.Set(recordingIdHeader, recordingId) - -// return req.Next() -// } - -// This looks up an environment variable and if it is not found, returns the recordedValue -func GetEnvVariable(t *testing.T, varName string, recordedValue string) string { - val, ok := os.LookupEnv(varName) - if !ok { - t.Logf("Could not find environment variable: %v", varName) - return recordedValue - } - return val -} - -func LiveOnly(t *testing.T) { - if GetRecordMode() != ModeRecording { - t.Skip("Live Test Only") - } -} - -// Function for sleeping during a test for `duration` seconds. This method will only execute when -// AZURE_RECORD_MODE = "record", if a test is running in playback this will be a noop. -func Sleep(duration int) { - if GetRecordMode() == ModeRecording { - time.Sleep(time.Duration(duration) * time.Second) - } -} - -func GetRecordingId() string { - return recordingId -} - -func GetRecordMode() string { - return recordMode -} - -func InPlayback() bool { - return GetRecordMode() == ModePlayback -} - -func InRecord() bool { - return GetRecordMode() == ModeRecording -} - -// type FakeCredential struct { -// accountName string -// accountKey string -// } - -// func NewFakeCredential(accountName, accountKey string) *FakeCredential { -// return &FakeCredential{ -// accountName: accountName, -// accountKey: accountKey, -// } -// } - -// func (f *FakeCredential) AuthenticationPolicy(azcore.AuthenticationPolicyOptions) azcore.Policy { -// return azcore.PolicyFunc(func(req *azcore.Request) (*azcore.Response, error) { -// authHeader := strings.Join([]string{"Authorization ", f.accountName, ":", f.accountKey}, "") -// req.Request.Header.Set(azcore.HeaderAuthorization, authHeader) -// return req.Next() -// }) -// } - -func getRootCas() (*x509.CertPool, error) { - localFile, ok := os.LookupEnv("PROXY_CERT") - - rootCAs, err := x509.SystemCertPool() - if err != nil { - rootCAs = x509.NewCertPool() - } - - if !ok { - fmt.Println("Could not find path to proxy certificate, set the environment variable 'PROXY_CERT' to the location of your certificate") - return rootCAs, nil - } - - cert, err := ioutil.ReadFile(*&localFile) - if err != nil { - fmt.Println("error opening cert file") - return nil, err - } - - if ok := rootCAs.AppendCertsFromPEM(cert); !ok { - fmt.Println("No certs appended, using system certs only") - } - - return rootCAs, nil -} - -func GetHTTPClient() (*http.Client, error) { - transport := http.DefaultTransport.(*http.Transport).Clone() - - rootCAs, err := getRootCas() - if err != nil { - return nil, err - } - - transport.TLSClientConfig.RootCAs = rootCAs - transport.TLSClientConfig.MinVersion = tls.VersionTLS12 - - defaultHttpClient := &http.Client{ - Transport: transport, - } - return defaultHttpClient, nil -} diff --git a/sdk/internal/recording/recording_test.go b/sdk/internal/recording/recording_test.go index 7d823939b891..b7d39956bb18 100644 --- a/sdk/internal/recording/recording_test.go +++ b/sdk/internal/recording/recording_test.go @@ -75,14 +75,12 @@ func (s *recordingTests) TestRecordedVariables() { // non existent variables return an error _, err = target.GetEnvVar(nonExistingEnvVar, NoSanitization) - // mark test as succeeded require.Equal(envNotExistsError(nonExistingEnvVar), err.Error()) // now create the env variable and check that it can be fetched os.Setenv(nonExistingEnvVar, expectedVariableValue) defer os.Unsetenv(nonExistingEnvVar) - val, err := target.GetEnvVar(nonExistingEnvVar, NoSanitization) require.NoError(err) require.Equal(expectedVariableValue, val) @@ -112,7 +110,6 @@ func (s *recordingTests) TestRecordedVariablesSanitized() { require.NoError(err) // call GetOptionalRecordedVariable with the Secret_String VariableType arg - require.Equal(secret, target.GetOptionalEnvVar(SanitizedStringVar, secret, Secret_String)) // call GetOptionalRecordedVariable with the Secret_Base64String VariableType arg @@ -151,7 +148,7 @@ func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables( target, err := NewRecording(context, Playback) require.NoError(err) - target.GetOptionalRecordedVariable(expectedVariableName, expectedVariableValue, Default) + target.GetOptionalEnvVar(expectedVariableName, expectedVariableValue, NoSanitization) err = target.Stop() require.NoError(err) @@ -168,7 +165,7 @@ func (s *recordingTests) TestStopSavesVariablesIfExistAndReadsPreviousVariables( require.NoError(err) // add a new variable to the existing batch - target2.GetOptionalRecordedVariable(addedVariableName, addedVariableValue, Default) + target2.GetOptionalEnvVar(addedVariableName, addedVariableValue, NoSanitization) err = target2.Stop() require.NoError(err) diff --git a/sdk/internal/recording/request_matcher.go b/sdk/internal/recording/request_matcher.go index 09301b86f68d..3ff392b955f0 100644 --- a/sdk/internal/recording/request_matcher.go +++ b/sdk/internal/recording/request_matcher.go @@ -17,7 +17,13 @@ import ( ) type RequestMatcher struct { - ignoredHeaders map[string]*string + context TestContext + // IgnoredHeaders is a map acting as a hash set of the header names that will be ignored for matching. + // Modifying the keys in the map will affect how headers are matched for recordings. + IgnoredHeaders map[string]struct{} + bodyMatcher StringMatcher + urlMatcher StringMatcher + methodMatcher StringMatcher } type StringMatcher func(reqVal string, recVal string) bool @@ -95,46 +101,46 @@ func (m *RequestMatcher) SetMethodMatcher(matcher StringMatcher) { } } -var recordingHeaderMissing = "Test recording headers do not match. Header '%s' is present in request but not in recording." -var requestHeaderMissing = "Test recording headers do not match. Header '%s' is present in recording but not in request." -var headerValuesMismatch = "Test recording header '%s' does not match. request: %s, recording: %s" -var methodMismatch = "Test recording methods do not match. request: %s, recording: %s" -var urlMismatch = "Test recording URLs do not match. request: %s, recording: %s" -var bodiesMismatch = "Test recording bodies do not match.\nrequest: %s\nrecording: %s" +func defaultStringMatcher(s1 string, s2 string) bool { + return s1 == s2 +} -func compareBodies(r *http.Request, i cassette.Request, c TestContext) bool { +func getBody(r *http.Request) string { body := bytes.Buffer{} if r.Body != nil { _, err := body.ReadFrom(r.Body) if err != nil { - return false + return "could not parse body: " + err.Error() } r.Body = ioutil.NopCloser(&body) } - bodiesMatch := body.String() == i.Body - if !bodiesMatch { - c.Log(fmt.Sprintf(bodiesMismatch, body.String(), i.Body)) - } - return bodiesMatch + return body.String() } -func compareURLs(r *http.Request, i cassette.Request, c TestContext) bool { - if r.URL.String() != i.URL { - c.Log(fmt.Sprintf(urlMismatch, r.URL.String(), i.URL)) - return false - } - return true +func getUrl(r *http.Request) string { + return r.URL.String() } -func compareMethods(r *http.Request, i cassette.Request, c TestContext) bool { - if r.Method != i.Method { - c.Log(fmt.Sprintf(methodMismatch, r.Method, i.Method)) - return false - } - return true +func getMethod(r *http.Request) string { + return r.Method +} + +func (m *RequestMatcher) compareBodies(r *http.Request, recordedBody string) bool { + body := getBody(r) + return m.bodyMatcher(body, recordedBody) +} + +func (m *RequestMatcher) compareURLs(r *http.Request, recordedUrl string) bool { + url := getUrl(r) + return m.urlMatcher(url, recordedUrl) +} + +func (m *RequestMatcher) compareMethods(r *http.Request, recordedMethod string) bool { + method := getMethod(r) + return m.methodMatcher(method, recordedMethod) } -func compareHeaders(r *http.Request, i cassette.Request, c TestContext) bool { +func (m *RequestMatcher) compareHeaders(r *http.Request, i cassette.Request) bool { unVisitedCassetteKeys := make(map[string]*string, len(i.Headers)) // clone the cassette keys to track which we have seen for k := range i.Headers { @@ -155,20 +161,20 @@ func compareHeaders(r *http.Request, i cassette.Request, c TestContext) bool { headersMatch := reflect.DeepEqual(requestHeader, recordedHeader) if !headersMatch { // headers don't match - c.Log(fmt.Sprintf(headerValuesMismatch, key, requestHeader, recordedHeader)) + m.context.Log(fmt.Sprintf(headerValuesMismatch, key, requestHeader, recordedHeader)) return false } } else { // header not found - c.Log(fmt.Sprintf(recordingHeaderMissing, key)) + m.context.Log(fmt.Sprintf(recordingHeaderMissing, key)) return false } } if len(unVisitedCassetteKeys) > 0 { // headers exist in the recording that do not exist in the request for headerName := range unVisitedCassetteKeys { - c.Log(fmt.Sprintf(requestHeaderMissing, headerName)) + m.context.Log(fmt.Sprintf(requestHeaderMissing, headerName)) } return false } diff --git a/sdk/internal/recording/request_matcher_test.go b/sdk/internal/recording/request_matcher_test.go index 4eb985fe57d8..c65eac65a3e0 100644 --- a/sdk/internal/recording/request_matcher_test.go +++ b/sdk/internal/recording/request_matcher_test.go @@ -34,18 +34,19 @@ const unMatchedBody string = "This body does not match." func (s *requestMatcherTests) TestCompareBodies() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := defaultMatcher(context) req := http.Request{Body: closerFromString(matchedBody)} recReq := cassette.Request{Body: matchedBody} - isMatch := compareBodies(&req, recReq, context) + isMatch := matcher.compareBodies(&req, recReq.Body) assert.Equal(true, isMatch) // make the requests mis-match req.Body = closerFromString((unMatchedBody)) - isMatch = compareBodies(&req, recReq, context) + isMatch = matcher.compareBodies(&req, recReq.Body) assert.False(isMatch) } @@ -61,6 +62,7 @@ func newUUID(t *testing.T) string { func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -74,12 +76,13 @@ func (s *requestMatcherTests) TestCompareHeadersIgnoresIgnoredHeaders() { recReq := cassette.Request{Headers: recordedHeaders} // All headers match - assert.True(compareHeaders(&req, recReq, context)) + assert.True(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -93,12 +96,13 @@ func (s *requestMatcherTests) TestCompareHeadersMatchesHeaders() { req := http.Request{Header: reqHeaders} recReq := cassette.Request{Headers: recordedHeaders} - assert.True(compareHeaders(&req, recReq, context)) + assert.True(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -116,12 +120,13 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingRecHeader() { // add a new header to the just req reqHeaders[header2] = headerValue - assert.False(compareHeaders(&req, recReq, context)) + assert.False(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -139,12 +144,13 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMissingReqHeader() { // add a new header to just the recording recordedHeaders[header2] = headerValue - assert.False(compareHeaders(&req, recReq, context)) + assert.False(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareHeadersFailsMismatchedValues() { assert := assert.New(s.T()) context := NewTestContext(func(msg string) { assert.FailNow(msg) }, func(msg string) { s.T().Log(msg) }, func() string { return s.T().Name() }) + matcher := defaultMatcher(context) // populate only ignored headers that do not match reqHeaders := make(http.Header) @@ -164,7 +170,7 @@ func (s *requestMatcherTests) TestCompareHeadersFailsMismatchedValues() { recordedHeaders[header2] = headerValue reqHeaders[header2] = mismatch - assert.False(compareHeaders(&req, recReq, context)) + assert.False(matcher.compareHeaders(&req, recReq)) } func (s *requestMatcherTests) TestCompareURLs() { @@ -174,12 +180,13 @@ func (s *requestMatcherTests) TestCompareURLs() { host := "foo.bar" req := http.Request{URL: &url.URL{Scheme: scheme, Host: host}} recReq := cassette.Request{URL: scheme + "://" + host} + matcher := defaultMatcher(context) - assert.True(compareURLs(&req, recReq, context)) + assert.True(matcher.compareURLs(&req, recReq.URL)) req.URL.Path = "noMatch" - assert.False(compareURLs(&req, recReq, context)) + assert.False(matcher.compareURLs(&req, recReq.URL)) } func (s *requestMatcherTests) TestCompareMethods() { @@ -189,12 +196,13 @@ func (s *requestMatcherTests) TestCompareMethods() { methodPatch := "PATCH" req := http.Request{Method: methodPost} recReq := cassette.Request{Method: methodPost} + matcher := defaultMatcher(context) - assert.True(compareMethods(&req, recReq, context)) + assert.True(matcher.compareMethods(&req, recReq.Method)) req.Method = methodPatch - assert.False(compareMethods(&req, recReq, context)) + assert.False(matcher.compareMethods(&req, recReq.Method)) } func closerFromString(content string) io.ReadCloser { diff --git a/sdk/internal/recording/sanitizer.go b/sdk/internal/recording/sanitizer.go index abaf182ce789..2d86d0021e77 100644 --- a/sdk/internal/recording/sanitizer.go +++ b/sdk/internal/recording/sanitizer.go @@ -15,21 +15,27 @@ import ( type Sanitizer struct { recorder *recorder.Recorder - headersToSanitize map[string]*string + headersToSanitize []string urlSanitizer StringSanitizer bodySanitizer StringSanitizer } +// StringSanitizer is a func that will modify the string pointed to by the parameter into a sanitized value. type StringSanitizer func(*string) +// SanitizedValue is the default placeholder value to be used for sanitized strings. const SanitizedValue string = "sanitized" + +// SanitizedBase64Value is the default placeholder value to be used for sanitized base-64 encoded strings. const SanitizedBase64Value string = "Kg==" var sanitizedValueSlice = []string{SanitizedValue} -func DefaultSanitizer(recorder *recorder.Recorder) *Sanitizer { +// defaultSanitizer returns a new RecordingSanitizer with the default sanitizing behavior. +// To customize sanitization, call AddSanitizedHeaders, AddBodySanitizer, or AddUrlSanitizer. +func defaultSanitizer(recorder *recorder.Recorder) *Sanitizer { // The default sanitizer sanitizes the Authorization header - s := &Sanitizer{headersToSanitize: map[string]*string{"Authorization": nil}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer} + s := &Sanitizer{headersToSanitize: []string{"Authorization"}, recorder: recorder, urlSanitizer: DefaultStringSanitizer, bodySanitizer: DefaultStringSanitizer} recorder.AddSaveFilter(s.applySaveFilter) return s @@ -51,7 +57,7 @@ func (s *Sanitizer) AddUrlSanitizer(sanitizer StringSanitizer) { } func (s *Sanitizer) sanitizeHeaders(header http.Header) { - for headerName := range s.headersToSanitize { + for _, headerName := range s.headersToSanitize { if _, ok := header[headerName]; ok { header[headerName] = sanitizedValueSlice } diff --git a/sdk/internal/recording/sanitizer_test.go b/sdk/internal/recording/sanitizer_test.go index ee70f602d2a3..f4e6c5d7ff72 100644 --- a/sdk/internal/recording/sanitizer_test.go +++ b/sdk/internal/recording/sanitizer_test.go @@ -40,7 +40,7 @@ func (s *sanitizerTests) TestDefaultSanitizerSanitizesAuthHeader() { rt := NewMockRoundTripper(server) r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) - DefaultSanitizer(r) + defaultSanitizer(r) req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) req.Header.Add(authHeader, "superSecret") @@ -68,7 +68,7 @@ func (s *sanitizerTests) TestAddSanitizedHeadersSanitizes() { rt := NewMockRoundTripper(server) r, _ := recorder.NewAsMode(getTestFileName(s.T(), false), recorder.ModeRecording, rt) - target := DefaultSanitizer(r) + target := defaultSanitizer(r) target.AddSanitizedHeaders(customHeader1, customHeader2) req, _ := http.NewRequest(http.MethodPost, server.URL(), nil) @@ -108,7 +108,7 @@ func (s *sanitizerTests) TestAddUrlSanitizerSanitizes() { baseUrl := server.URL() + "/" - target := DefaultSanitizer(r) + target := defaultSanitizer(r) target.AddUrlSanitizer(func(url *string) { *url = strings.Replace(*url, secret, SanitizedValue, -1) })