From e704f09b291ab48dc65fe5a9698e03822012548e Mon Sep 17 00:00:00 2001 From: Alan Hamlett Date: Sat, 22 May 2021 07:52:33 -0700 Subject: [PATCH] Prevent using DefaultTransport --- draft/usage.md | 105 ------------------------------------------- pkg/api/option.go | 30 ++----------- pkg/api/transport.go | 8 ++++ 3 files changed, 12 insertions(+), 131 deletions(-) delete mode 100644 draft/usage.md diff --git a/draft/usage.md b/draft/usage.md deleted file mode 100644 index bb85a533..00000000 --- a/draft/usage.md +++ /dev/null @@ -1,105 +0,0 @@ -import ( - "net/http" - "time" - - "github.com/wakatime/wakatime-cli/pkg/api" - "github.com/wakatime/wakatime-cli/pkg/filestats" - "github.com/wakatime/wakatime-cli/pkg/deps" - "github.com/wakatime/wakatime-cli/pkg/heartbeat" - "github.com/wakatime/wakatime-cli/pkg/language" - "github.com/wakatime/wakatime-cli/pkg/log" - "github.com/wakatime/wakatime-cli/pkg/offline" - "github.com/wakatime/wakatime-cli/pkg/project" -) - -const ( - queueDBFile = ".wakatime.db" - queueDBTable = "heartbeat_2" -) - -func main() { - withAuth, err := api.WithAuth(api.BasicAuth{ - Secret: args.APIKey, - }) - if err != nil { - log.Fatalf(err) - } - - clientOpts := []api.Option{ - withAuth, - api.WithHostName(args.HostName), - } - - if args.SSLCert != nil { - clientOpts = append(clientOpts, api.WithSSL(args.SSLCert)) - } - - if args.Timeout != nil { - clientOpts = append(clientOpts, api.WithTimeout(args.Timeout * time.Second)) - } - - if args.Plugin != nil { - clientOpts = append(clientOpts, api.WithUserAgentFromPlugin(args.Plugin)) - } else { - clientOpts = append(clientOpts, api.WithUserAgent()) - } - - client := api.NewClient(baseURL, http.DefaultClient, clientOpts...) - - var withDepsDetection heartbeat.HandleOption - if args.Localfile == "" { - withDepsDetection = deps.WithDetection() - } else - withDepsDetection = deps.WithDetectionOnFile(args.Localfile) - } - - var withFilestatsDetection heartbeat.HandleOption - if args.Localfile == "" { - withFilestatsDetection = filestats.WithDetection() - } else - withFilestatsDetection = filestats.WithDetectionOnFile(args.Localfile) - } - - handleOpts := []heartbeat.HandleOption{ - heartbeat.WithSanitization(heartbeat.SanitizeConfig{ - HideBranchNames: args.HideBranchNames, - HideFileNames: args.HideFileNames, - HideProjectNames: args.HideProjectNames, - }), - offline.WithQueue(queueDBFile, queueDBTable), - language.WithDetection(language.Config{ - Alternative: args.AlternativeLanguage, - Overwrite: args.Language, - LocalFile: args.LocalFile, - }), - withDepsDetection, - withFilestatsDetection, - project.WithDetection(project.Config{ - Alternative: args.AlternativeProject, - Overwrite: args.Project, - LocalFile: args.LocalFile, - }), - heartbeat.WithValidation(heartbeat.ValidateConfig{ - Exclude: args.Exclude, - ExcludeUnknownProject: args.ExcludeUnknownProject, - Include: args.Include, - IncludeOnlyWithProjectFile: args.IncludeOnlyWithProjectFile, - ), - } - handle := heartbeat.NewHandle(client, handleOpts...) - - hh := []Heartbeat{ - { - Category: args.Category, - Entity: args.Entity, - EntityType: args.EntityType, - IsWrite: args.IsWrite, - Time: args.Time, - UserAgent: arg.UserAgent, - } - } - _, err := handle(hh) - if err != nil { - log.Fatalf(err) - } -} diff --git a/pkg/api/option.go b/pkg/api/option.go index 6e9c7e08..4af541a0 100644 --- a/pkg/api/option.go +++ b/pkg/api/option.go @@ -49,12 +49,7 @@ func WithHostname(hostname string) Option { // WithDisableSSLVerify disables verification of insecure certificates. func WithDisableSSLVerify() Option { return func(c *Client) { - var transport *http.Transport - if c.client.Transport == nil { - transport = http.DefaultTransport.(*http.Transport).Clone() - } else { - transport = c.client.Transport.(*http.Transport).Clone() - } + var transport *http.Transport = GetOrCreateTransport(c) tlsConfig := transport.TLSClientConfig tlsConfig.InsecureSkipVerify = true @@ -88,13 +83,7 @@ func WithNTLM(creds string) (Option, error) { return func(c *Client) { withAuth(c) - var transport http.RoundTripper - if c.client.Transport == nil { - transport = http.DefaultTransport - } else { - transport = c.client.Transport.(*http.Transport).Clone() - } - + var transport *http.Transport = GetOrCreateTransport(c) c.client.Transport = ntlmssp.Negotiator{ RoundTripper: transport, } @@ -134,13 +123,8 @@ func WithProxy(proxyURL string) (Option, error) { } return func(c *Client) { - transport := http.DefaultTransport.(*http.Transport).Clone() - if c.client.Transport != nil { - transport = c.client.Transport.(*http.Transport).Clone() - } - + var transport *http.Transport = GetOrCreateTransport(c) transport.Proxy = http.ProxyURL(u) - c.client.Transport = transport }, nil } @@ -166,13 +150,7 @@ func WithSSLCertFile(filepath string) (Option, error) { // WithSSLCertPool overrides the default CA cert pool to trust specified cert pool. func WithSSLCertPool(caCertPool *x509.CertPool) (Option, error) { return func(c *Client) { - var transport *http.Transport - if c.client.Transport == nil { - transport = http.DefaultTransport.(*http.Transport).Clone() - } else { - transport = c.client.Transport.(*http.Transport).Clone() - } - + var transport *http.Transport = GetOrCreateTransport(c) tlsConfig := transport.TLSClientConfig tlsConfig.RootCAs = caCertPool transport.TLSClientConfig = tlsConfig diff --git a/pkg/api/transport.go b/pkg/api/transport.go index 44ed3bfc..cf51ee5c 100644 --- a/pkg/api/transport.go +++ b/pkg/api/transport.go @@ -16,3 +16,11 @@ func NewTransport() *http.Transport { ForceAttemptHTTP2: true, } } + +// GetOrCreateTransport gets the client's Transport if already exists, or initializes a new one +func GetOrCreateTransport(c *Client) *http.Transport { + if c.client.Transport != nil { + return c.client.Transport.(*http.Transport).Clone() + } + return NewTransport() +}