Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixing internal #15315

Merged
merged 1 commit into from
Aug 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 19 additions & 287 deletions sdk/internal/recording/recording.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
package recording

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
Expand All @@ -20,7 +16,6 @@ import (
"path/filepath"
"strconv"
"strings"
"testing"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/internal/uuid"
Expand All @@ -40,6 +35,7 @@ type Recording struct {
src rand.Source
now *time.Time
Sanitizer *Sanitizer
Matcher *RequestMatcher
c TestContext
}

Expand Down Expand Up @@ -69,8 +65,11 @@ const (
type VariableType string

const (
Default VariableType = "default"
Secret_String VariableType = "secret_string"
// NoSanitization indicates that the recorded value should not be sanitized.
NoSanitization VariableType = "default"
// Secret_String indicates that the recorded value should be replaced with a sanitized value.
Secret_String VariableType = "secret_string"
// Secret_Base64String indicates that the recorded value should be replaced with a sanitized valid base-64 string value.
Secret_Base64String VariableType = "secret_base64String"
)

Expand Down Expand Up @@ -107,17 +106,18 @@ func NewRecording(c TestContext, mode RecordMode) (*Recording, error) {
}

// set the recorder Matcher
recording.Matcher = defaultMatcher(c)
rec.SetMatcher(recording.matchRequest)

// wire up the sanitizer
recording.Sanitizer = DefaultSanitizer(rec)
recording.Sanitizer = defaultSanitizer(rec)

return recording, err
}

// GetRecordedVariable returns a recorded variable. If the variable is not found we return an error
// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation.
func (r *Recording) GetRecordedVariable(name string, variableType VariableType) (string, error) {
// GetEnvVar returns a recorded environment variable. If the variable is not found we return an error.
// variableType determines how the recorded variable will be saved.
func (r *Recording) GetEnvVar(name string, variableType VariableType) (string, error) {
var err error
result, ok := r.previousSessionVariables[name]
if !ok || r.Mode == Live {
Expand All @@ -132,9 +132,10 @@ func (r *Recording) GetRecordedVariable(name string, variableType VariableType)
return *result, err
}

// GetOptionalRecordedVariable returns a recorded variable with a fallback default value
// variableType determines how the recorded variable will be saved. Default indicates that the value should be saved without any sanitation.
func (r *Recording) GetOptionalRecordedVariable(name string, defaultValue string, variableType VariableType) string {
// GetOptionalEnvVar returns a recorded environment variable with a fallback default value.
// default Value configures the fallback value to be returned if the environment variable is not set.
// variableType determines how the recorded variable will be saved.
func (r *Recording) GetOptionalEnvVar(name string, defaultValue string, variableType VariableType) string {
result, ok := r.previousSessionVariables[name]
if !ok || r.Mode == Live {
result = getOptionalEnv(name, defaultValue)
Expand Down Expand Up @@ -280,10 +281,10 @@ func getOptionalEnv(name string, defaultValue string) *string {
}

func (r *Recording) matchRequest(req *http.Request, rec cassette.Request) bool {
isMatch := compareMethods(req, rec, r.c) &&
compareURLs(req, rec, r.c) &&
compareHeaders(req, rec, r.c) &&
compareBodies(req, rec, r.c)
isMatch := r.Matcher.compareMethods(req, rec.Method) &&
r.Matcher.compareURLs(req, rec.URL) &&
r.Matcher.compareHeaders(req, rec) &&
r.Matcher.compareBodies(req, rec.Body)

return isMatch
}
Expand Down Expand Up @@ -432,272 +433,3 @@ var modeMap = map[RecordMode]recorder.Mode{
Live: recorder.ModeDisabled,
Playback: recorder.ModeReplaying,
}

var recordMode, _ = os.LookupEnv("AZURE_RECORD_MODE")
var ModeRecording = "record"
var ModePlayback = "playback"

var baseProxyURLSecure = "localhost:5001"
var baseProxyURL = "localhost:5000"
var startURL = baseProxyURLSecure + "/record/start"
var stopURL = baseProxyURLSecure + "/record/stop"

var recordingId string
var IdHeader = "x-recording-id"
var ModeHeader = "x-recording-mode"
var UpstreamUriHeader = "x-recording-upstream-base-uri"

var tr = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
var client = http.Client{
Transport: tr,
}

type RecordingOptions struct {
MaxRetries int32
UseHTTPS bool
Host string
Scheme string
}

func defaultOptions() *RecordingOptions {
return &RecordingOptions{
MaxRetries: 0,
UseHTTPS: true,
Host: "localhost:5001",
Scheme: "https",
}
}

func (r RecordingOptions) HostScheme() string {
if r.UseHTTPS {
return "https://localhost:5001"
}
return "http://localhost:5000"
}

func getTestId(t *testing.T) string {
cwd, err := os.Getwd()
if err != nil {
t.Errorf("Could not find current working directory")
}
cwd = "./recordings/" + t.Name() + ".json"
return cwd
}

func StartRecording(t *testing.T, options *RecordingOptions) error {
if options == nil {
options = defaultOptions()
}
if recordMode == "" {
t.Log("AZURE_RECORD_MODE was not set, options are \"record\" or \"playback\". \nDefaulting to playback")
recordMode = "playback"
} else {
t.Log("AZURE_RECORD_MODE: ", recordMode)
}
testId := getTestId(t)

url := fmt.Sprintf("%v/%v/start", options.HostScheme(), recordMode)

req, err := http.NewRequest("POST", url, nil)
if err != nil {
return err
}

req.Header.Set("x-recording-file", testId)

resp, err := client.Do(req)
if err != nil {
return err
}
recordingId = resp.Header.Get(IdHeader)
return nil
}

func StopRecording(t *testing.T, options *RecordingOptions) error {
if options == nil {
options = defaultOptions()
}

url := fmt.Sprintf("%v/%v/stop", options.HostScheme(), recordMode)
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return err
}
if recordingId == "" {
return errors.New("Recording ID was never set. Did you call StartRecording?")
}
req.Header.Set("x-recording-id", recordingId)
_, err = client.Do(req)
if err != nil {
t.Errorf(err.Error())
}
return nil
}

func AddUriSanitizer(replacement, regex string, options *RecordingOptions) error {
if options == nil {
options = defaultOptions()
}
url := fmt.Sprintf("%v/Admin/AddSanitizer", options.HostScheme())
req, err := http.NewRequest("POST", url, nil)
if err != nil {
return err
}
req.Header.Set("x-abstraction-identifier", "UriRegexSanitizer")
bodyContent := map[string]string{
"value": replacement,
"regex": regex,
}
marshalled, err := json.Marshal(bodyContent)
if err != nil {
return err
}
req.Body = ioutil.NopCloser(bytes.NewReader(marshalled))
req.ContentLength = int64(len(marshalled))
_, err = client.Do(req)
return err
}

func (o *RecordingOptions) Init() {
if o.MaxRetries != 0 {
o.MaxRetries = 0
}
if o.UseHTTPS {
o.Host = baseProxyURLSecure
o.Scheme = "https"
} else {
o.Host = baseProxyURL
o.Scheme = "http"
}
}

// type recordingPolicy struct {
// options RecordingOptions
// }

// func NewRecordingPolicy(o *RecordingOptions) azcore.Policy {
// if o == nil {
// o = &RecordingOptions{}
// }
// p := &recordingPolicy{options: *o}
// p.options.init()
// return p
// }

// func (p *recordingPolicy) Do(req *azcore.Request) (resp *azcore.Response, err error) {
// originalURLHost := req.URL.Host
// req.URL.Scheme = "https"
// req.URL.Host = p.options.host
// req.Host = p.options.host

// req.Header.Set(UpstreamUriHeader, fmt.Sprintf("%v://%v", p.options.scheme, originalURLHost))
// req.Header.Set(ModeHeader, recordMode)
// req.Header.Set(recordingIdHeader, recordingId)

// return req.Next()
// }

// This looks up an environment variable and if it is not found, returns the recordedValue
func GetEnvVariable(t *testing.T, varName string, recordedValue string) string {
val, ok := os.LookupEnv(varName)
if !ok {
t.Logf("Could not find environment variable: %v", varName)
return recordedValue
}
return val
}

func LiveOnly(t *testing.T) {
if GetRecordMode() != ModeRecording {
t.Skip("Live Test Only")
}
}

// Function for sleeping during a test for `duration` seconds. This method will only execute when
// AZURE_RECORD_MODE = "record", if a test is running in playback this will be a noop.
func Sleep(duration int) {
if GetRecordMode() == ModeRecording {
time.Sleep(time.Duration(duration) * time.Second)
}
}

func GetRecordingId() string {
return recordingId
}

func GetRecordMode() string {
return recordMode
}

func InPlayback() bool {
return GetRecordMode() == ModePlayback
}

func InRecord() bool {
return GetRecordMode() == ModeRecording
}

// type FakeCredential struct {
// accountName string
// accountKey string
// }

// func NewFakeCredential(accountName, accountKey string) *FakeCredential {
// return &FakeCredential{
// accountName: accountName,
// accountKey: accountKey,
// }
// }

// func (f *FakeCredential) AuthenticationPolicy(azcore.AuthenticationPolicyOptions) azcore.Policy {
// return azcore.PolicyFunc(func(req *azcore.Request) (*azcore.Response, error) {
// authHeader := strings.Join([]string{"Authorization ", f.accountName, ":", f.accountKey}, "")
// req.Request.Header.Set(azcore.HeaderAuthorization, authHeader)
// return req.Next()
// })
// }

func getRootCas() (*x509.CertPool, error) {
localFile, ok := os.LookupEnv("PROXY_CERT")

rootCAs, err := x509.SystemCertPool()
if err != nil {
rootCAs = x509.NewCertPool()
}

if !ok {
fmt.Println("Could not find path to proxy certificate, set the environment variable 'PROXY_CERT' to the location of your certificate")
return rootCAs, nil
}

cert, err := ioutil.ReadFile(*&localFile)
if err != nil {
fmt.Println("error opening cert file")
return nil, err
}

if ok := rootCAs.AppendCertsFromPEM(cert); !ok {
fmt.Println("No certs appended, using system certs only")
}

return rootCAs, nil
}

func GetHTTPClient() (*http.Client, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()

rootCAs, err := getRootCas()
if err != nil {
return nil, err
}

transport.TLSClientConfig.RootCAs = rootCAs
transport.TLSClientConfig.MinVersion = tls.VersionTLS12

defaultHttpClient := &http.Client{
Transport: transport,
}
return defaultHttpClient, nil
}
Loading