diff --git a/testhelper/gittest/gittest.go b/testhelper/gittest/gittest.go new file mode 100644 index 00000000..8600e667 --- /dev/null +++ b/testhelper/gittest/gittest.go @@ -0,0 +1,124 @@ +// Package gittest provides a test helper for creating a git server that serves a git repository. +package gittest + +import ( + "encoding/pem" + "fmt" + "net" + "net/http/cgi" + "net/url" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/testhelper/httptest" +) + +type Server struct { + *httptest.Server + URL string +} + +// NewHttpServer creates a new httptest.Server that serves a git repository from the given sourcePath. +func NewHttpServer(t testing.TB, sourcePath string) *Server { + return newServer(t, sourcePath, false) +} + +// NewHttpsServer creates a new httptest.Server that serves a git repository from the given sourcePath. +func NewHttpsServer(t testing.TB, sourcePath string) *Server { + return newServer(t, sourcePath, true) +} + +// newServer creates a new httptest.Server that serves a git repository from the given sourcePath. +func newServer(t testing.TB, sourcePath string, secure bool) *Server { + t.Helper() + tempDir := t.TempDir() + org := "org" + repo := filepath.Base(sourcePath) + if !strings.HasSuffix(repo, ".git") { + repo = repo + ".git" + } + source := sourcePath + if !strings.HasSuffix(sourcePath, "/") { + source = source + "/" + } + gitRoot := filepath.Join(tempDir, org, repo) + require.NoErrorf(t, os.MkdirAll(gitRoot, os.ModePerm), "should be able to create %s", gitRoot) + require.NoErrorf(t, exec.Command("rsync", "--recursive", source, gitRoot).Run(), "should be able to copy %s to %s", source, gitRoot) + gitPath, err := exec.LookPath("git") + require.NoError(t, err, "should be able to find git in PATH") + require.NoError(t, exec.Command("git", "init", gitRoot).Run(), "should be able to initialize git repository") + require.NoError(t, exec.Command("git", "-C", gitRoot, "add", ".").Run(), "should be able to add files to git repository") + commitCmd := exec.Command("git", "-C", gitRoot, "commit", "-m", "initial commit") + commitCmd.Env = append(commitCmd.Env, "GIT_AUTHOR_NAME=git test", "GIT_COMMITTER_NAME=git test", "GIT_AUTHOR_EMAIL=gittest@example.com", "GIT_COMMITTER_EMAIL=gittest@example.com") + require.NoError(t, commitCmd.Run(), "should be able to commit files to git repository") + + handler := &cgi.Handler{ + Path: gitPath, + Args: []string{"http-backend"}, + Env: []string{ + fmt.Sprintf("GIT_PROJECT_ROOT=%s", tempDir), + "GIT_HTTP_EXPORT_ALL=true", + }, + } + + localIP := getLocalIP(t) + var s *httptest.Server + if !secure { + s = httptest.NewServer(handler) + } else { + s = httptest.NewTLSServer(localIP, handler) + certPath := filepath.Join(tempDir, "server.crt") + require.NoError(t, writeServerCA(s, certPath)) + t.Setenv("SSL_CERT_FILE", certPath) + + } + serverURL, err := url.Parse(s.URL) + require.NoError(t, err) + _, port, err := net.SplitHostPort(serverURL.Host) + require.NoError(t, err) + serverURL.Host = net.JoinHostPort(getLocalIP(t), port) + url := serverURL.String() + "/" + org + "/" + repo + return &Server{ + Server: s, + URL: url, + } +} + +// getServerCA returns a byte slice containing the PEM encoding of the server's CA certificate +func (s *Server) GetServerCA() []byte { + return getServerCA(s.Server) +} + +func getServerCA(server *httptest.Server) []byte { + return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: server.TLS.Certificates[0].Certificate[0]}) +} + +// writeServerCA writes the PEM-encoded server certificate to a given path +func writeServerCA(server *httptest.Server, path string) error { + certOut, err := os.Create(path) + if err != nil { + return err + } + defer certOut.Close() + + if _, err := certOut.Write(getServerCA(server)); err != nil { + return err + } + + return nil +} + +func getLocalIP(t testing.TB) string { + conn, err := net.Dial("udp", "8.8.8.8:80") + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close()) + }() + localAddr := conn.LocalAddr().(*net.UDPAddr) + return localAddr.IP.String() +} diff --git a/testhelper/gittest/gittest_test.go b/testhelper/gittest/gittest_test.go new file mode 100644 index 00000000..23799e44 --- /dev/null +++ b/testhelper/gittest/gittest_test.go @@ -0,0 +1,30 @@ +package gittest_test + +import ( + "os/exec" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/rudderlabs/rudder-go-kit/testhelper/gittest" +) + +func TestGitServer(t *testing.T) { + t.Run("http", func(t *testing.T) { + s := gittest.NewHttpServer(t, "testdata/gitrepo") + defer s.Close() + tempDir := t.TempDir() + url := s.URL + require.NoError(t, exec.Command("git", "-c", "http.sslVerify=false", "clone", url, tempDir).Run(), "should be able to clone the repository") + require.FileExists(t, tempDir+"/README.md", "README.md should exist in the cloned repository") + }) + + t.Run("https", func(t *testing.T) { + s := gittest.NewHttpsServer(t, "testdata/gitrepo") + defer s.Close() + tempDir := t.TempDir() + url := s.URL + require.NoError(t, exec.Command("git", "-c", "http.sslVerify=false", "clone", url, tempDir).Run(), "should be able to clone the repository") + require.FileExists(t, tempDir+"/README.md", "README.md should exist in the cloned repository") + }) +} diff --git a/testhelper/gittest/testdata/gitrepo/README.md b/testhelper/gittest/testdata/gitrepo/README.md new file mode 100644 index 00000000..5e1c309d --- /dev/null +++ b/testhelper/gittest/testdata/gitrepo/README.md @@ -0,0 +1 @@ +Hello World \ No newline at end of file diff --git a/testhelper/httptest/httptest.go b/testhelper/httptest/httptest.go index 93b71ee3..6e21e662 100644 --- a/testhelper/httptest/httptest.go +++ b/testhelper/httptest/httptest.go @@ -1,49 +1,465 @@ +// Package httptest provides similar functionality to the net/http/httptest package, but with some additional features: +// +// - It allows you to listen to all interfaces, not just localhost. +// - For TLS servers, it allows you to specify the host name to use in the certificate. package httptest import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "flag" "fmt" + "log" + "math/big" "net" "net/http" - nethttptest "net/http/httptest" + "os" + "strings" + "sync" + "time" ) -// NewServer starts a new httptest server that listens on all interfaces, contrary to the standard net/httptest.Server that listens only on localhost. -// This is useful when you want to access the test http server from within a docker container. +// A Server is an HTTP server listening on a system-chosen port on the +// local loopback interface, for use in end-to-end HTTP tests. +type Server struct { + URL string // base URL of form http://ipaddr:port with no trailing slash + Listener net.Listener + + // EnableHTTP2 controls whether HTTP/2 is enabled + // on the server. It must be set between calling + // NewUnstartedServer and calling Server.StartTLS. + EnableHTTP2 bool + + // TLS is the optional TLS configuration, populated with a new config + // after TLS is started. If set on an unstarted server before StartTLS + // is called, existing fields are copied into the new config. + TLS *tls.Config + + // Config may be changed after calling NewUnstartedServer and + // before Start or StartTLS. + Config *http.Server + + // certificate is a parsed version of the TLS config certificate, if present. + certificate *x509.Certificate + + // wg counts the number of outstanding HTTP requests on this server. + // Close blocks until all requests are finished. + wg sync.WaitGroup + + mu sync.Mutex // guards closed and conns + closed bool + conns map[net.Conn]http.ConnState // except terminal states + + // client is configured for use with the server. + // Its transport is automatically closed when Close is called. + client *http.Client +} + +func newLocalListener() net.Listener { + if serveFlag != "" { + l, err := net.Listen("tcp", serveFlag) + if err != nil { + panic(fmt.Sprintf("httptest: failed to listen on %v: %v", serveFlag, err)) + } + return l + } + l, err := net.Listen("tcp", "0.0.0.0:0") + if err != nil { + if l, err = net.Listen("tcp6", "[::]:0"); err != nil { + panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) + } + } + return l +} + +// When debugging a particular http server-based test, +// this flag lets you run +// +// go test -run='^BrokenTest$' -httptest.serve=127.0.0.1:8000 +// +// to start the broken server so you can interact with it manually. +// We only register this flag if it looks like the caller knows about it +// and is trying to use it as we don't want to pollute flags and this +// isn't really part of our API. Don't depend on this. +var serveFlag string + +func init() { + if strSliceContainsPrefix(os.Args, "-httptest.serve=") || strSliceContainsPrefix(os.Args, "--httptest.serve=") { + flag.StringVar(&serveFlag, "httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks.") + } +} + +func strSliceContainsPrefix(v []string, pre string) bool { + for _, s := range v { + if strings.HasPrefix(s, pre) { + return true + } + } + return false +} + +// NewServer starts and returns a new [Server]. +// The caller should call Close when finished, to shut it down. func NewServer(handler http.Handler) *Server { - ts := NewUnStartedServer(handler) - ts.start() + ts := NewUnstartedServer(handler) + ts.Start() return ts } -// Server wraps net/httptest.Server to listen on all network interfaces -type Server struct { - *nethttptest.Server +// NewUnstartedServer returns a new [Server] but doesn't start it. +// +// After changing its configuration, the caller should call Start or +// StartTLS. +// +// The caller should call Close when finished, to shut it down. +func NewUnstartedServer(handler http.Handler) *Server { + return &Server{ + Listener: newLocalListener(), + Config: &http.Server{Handler: handler}, + } +} + +// Start starts a server from NewUnstartedServer. +func (s *Server) Start() { + if s.URL != "" { + panic("Server already started") + } + if s.client == nil { + s.client = &http.Client{Transport: &http.Transport{}} + } + _, port, _ := net.SplitHostPort(s.Listener.Addr().String()) + s.URL = "http://" + net.JoinHostPort("127.0.0.1", port) + s.wrap() + s.goServe() + if serveFlag != "" { + fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL) + select {} + } } -func (s *Server) start() { - s.Server.Start() - _, port, err := net.SplitHostPort(s.Listener.Addr().String()) +// StartTLS starts TLS on a server from NewUnstartedServer. +func (s *Server) StartTLS(host string) { + if s.URL != "" { + panic("Server already started") + } + if s.client == nil { + s.client = &http.Client{} + } + hosts := []string{"127.0.0.1", "::1", "example.com"} + if host != "" { + hosts = append(hosts, host) + } + pem, key := generateCert(strings.Join(hosts, ",")) + cert, err := tls.X509KeyPair(pem, key) if err != nil { - panic(fmt.Sprintf("httptest: failed to parse listener address: %v", err)) + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + + existingConfig := s.TLS + if existingConfig != nil { + s.TLS = existingConfig.Clone() + } else { + s.TLS = new(tls.Config) + } + if s.TLS.NextProtos == nil { + nextProtos := []string{"http/1.1"} + if s.EnableHTTP2 { + nextProtos = []string{"h2"} + } + s.TLS.NextProtos = nextProtos } - s.URL = fmt.Sprintf("http://%s:%s", "localhost", port) + if len(s.TLS.Certificates) == 0 { + s.TLS.Certificates = []tls.Certificate{cert} + } + s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + certpool := x509.NewCertPool() + certpool.AddCert(s.certificate) + s.client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certpool, + }, + ForceAttemptHTTP2: s.EnableHTTP2, + } + s.Listener = tls.NewListener(s.Listener, s.TLS) + + _, port, _ := net.SplitHostPort(s.Listener.Addr().String()) + listenerhost := "127.0.0.1" + if host != "" { + listenerhost = host + } + + s.URL = "https://" + net.JoinHostPort(listenerhost, port) + s.wrap() + s.goServe() } -func NewUnStartedServer(handler http.Handler) *Server { - return &Server{&nethttptest.Server{ - Listener: newListener(), - Config: &http.Server{Handler: handler}, - }} +// NewTLSServer starts and returns a new [Server] using TLS. +// The caller should call Close when finished, to shut it down. +func NewTLSServer(host string, handler http.Handler) *Server { + ts := NewUnstartedServer(handler) + ts.StartTLS(host) + return ts +} + +type closeIdleTransport interface { + CloseIdleConnections() } -func newListener() net.Listener { - listener, tcpError := net.Listen("tcp", "0.0.0.0:0") - if tcpError == nil { - return listener +// Close shuts down the server and blocks until all outstanding +// requests on this server have completed. +func (s *Server) Close() { + s.mu.Lock() + if !s.closed { + s.closed = true + s.Listener.Close() + s.Config.SetKeepAlivesEnabled(false) + for c, st := range s.conns { + // Force-close any idle connections (those between + // requests) and new connections (those which connected + // but never sent a request). StateNew connections are + // super rare and have only been seen (in + // previously-flaky tests) in the case of + // socket-late-binding races from the http Client + // dialing this server and then getting an idle + // connection before the dial completed. There is thus + // a connected connection in StateNew with no + // associated Request. We only close StateIdle and + // StateNew because they're not doing anything. It's + // possible StateNew is about to do something in a few + // milliseconds, but a previous CL to check again in a + // few milliseconds wasn't liked (early versions of + // https://golang.org/cl/15151) so now we just + // forcefully close StateNew. The docs for Server.Close say + // we wait for "outstanding requests", so we don't close things + // in StateActive. + if st == http.StateIdle || st == http.StateNew { + s.closeConn(c) + } + } + // If this server doesn't shut down in 5 seconds, tell the user why. + t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo) + defer t.Stop() + } + s.mu.Unlock() + + // Not part of httptest.Server's correctness, but assume most + // users of httptest.Server will be using the standard + // transport, so help them out and close any idle connections for them. + if t, ok := http.DefaultTransport.(closeIdleTransport); ok { + t.CloseIdleConnections() + } + + // Also close the client idle connections. + if s.client != nil { + if t, ok := s.client.Transport.(closeIdleTransport); ok { + t.CloseIdleConnections() + } + } + + s.wg.Wait() +} + +func (s *Server) logCloseHangDebugInfo() { + s.mu.Lock() + defer s.mu.Unlock() + var buf strings.Builder + buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n") + for c, st := range s.conns { + fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st) + } + log.Print(buf.String()) +} + +// CloseClientConnections closes any open HTTP connections to the test Server. +func (s *Server) CloseClientConnections() { + s.mu.Lock() + nconn := len(s.conns) + ch := make(chan struct{}, nconn) + for c := range s.conns { + go s.closeConnChan(c, ch) + } + s.mu.Unlock() + + // Wait for outstanding closes to finish. + // + // Out of paranoia for making a late change in Go 1.6, we + // bound how long this can wait, since golang.org/issue/14291 + // isn't fully understood yet. At least this should only be used + // in tests. + timer := time.NewTimer(5 * time.Second) + defer timer.Stop() + for i := 0; i < nconn; i++ { + select { + case <-ch: + case <-timer.C: + // Too slow. Give up. + return + } + } +} + +// Certificate returns the certificate used by the server, or nil if +// the server doesn't use TLS. +func (s *Server) Certificate() *x509.Certificate { + return s.certificate +} + +// Client returns an HTTP client configured for making requests to the server. +// It is configured to trust the server's TLS test certificate and will +// close its idle connections on [Server.Close]. +func (s *Server) Client() *http.Client { + return s.client +} + +func (s *Server) goServe() { + s.wg.Add(1) + go func() { + defer s.wg.Done() + _ = s.Config.Serve(s.Listener) + }() +} + +// wrap installs the connection state-tracking hook to know which +// connections are idle. +func (s *Server) wrap() { + oldHook := s.Config.ConnState + s.Config.ConnState = func(c net.Conn, cs http.ConnState) { + s.mu.Lock() + defer s.mu.Unlock() + + switch cs { + case http.StateNew: + if _, exists := s.conns[c]; exists { + panic("invalid state transition") + } + if s.conns == nil { + s.conns = make(map[net.Conn]http.ConnState) + } + // Add c to the set of tracked conns and increment it to the + // waitgroup. + s.wg.Add(1) + s.conns[c] = cs + if s.closed { + // Probably just a socket-late-binding dial from + // the default transport that lost the race (and + // thus this connection is now idle and will + // never be used). + s.closeConn(c) + } + case http.StateActive: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateNew && oldState != http.StateIdle { + panic("invalid state transition") + } + s.conns[c] = cs + } + case http.StateIdle: + if oldState, ok := s.conns[c]; ok { + if oldState != http.StateActive { + panic("invalid state transition") + } + s.conns[c] = cs + } + if s.closed { + s.closeConn(c) + } + case http.StateHijacked, http.StateClosed: + // Remove c from the set of tracked conns and decrement it from the + // waitgroup, unless it was previously removed. + if _, ok := s.conns[c]; ok { + delete(s.conns, c) + // Keep Close from returning until the user's ConnState hook + // (if any) finishes. + defer s.wg.Done() + } + } + if oldHook != nil { + oldHook(c, cs) + } + } +} + +// closeConn closes c. +// s.mu must be held. +func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) } + +// closeConnChan is like closeConn, but takes an optional channel to receive a value +// when the goroutine closing c is done. +func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) { + c.Close() + if done != nil { + done <- struct{}{} + } +} + +func generateCert(host string) (cert, key []byte) { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + log.Fatalf("Failed to generate private key: %v", err) + } + keyUsage := x509.KeyUsageDigitalSignature + keyUsage |= x509.KeyUsageKeyEncipherment + + var notBefore time.Time + notBefore, err = time.Parse("Jan 2 15:04:05 2006", "Jan 1 00:00:00 1970") + if err != nil { + log.Fatalf("Failed to parse creation date: %v", err) + } + notAfter := notBefore.Add(1000000 * time.Hour) + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + log.Fatalf("Failed to generate serial number: %v", err) + } + + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{"Acme Co"}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + KeyUsage: keyUsage, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + hosts := strings.Split(host, ",") + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + log.Fatalf("Failed to create certificate: %v", err) + } + var certOut bytes.Buffer + if err != nil { + log.Fatalf("Failed to open cert.pem for writing: %v", err) + } + if err := pem.Encode(&certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { + log.Fatalf("Failed to write data to cert.pem: %v", err) + } + + var keyOut bytes.Buffer + privBytes, err := x509.MarshalPKCS8PrivateKey(priv) + if err != nil { + log.Fatalf("Unable to marshal private key: %v", err) } - listener, tcp6Error := net.Listen("tcp6", "[::]:0") - if tcp6Error == nil { - return listener + if err := pem.Encode(&keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { + log.Fatalf("Failed to write data to key.pem: %v", err) } - panic(fmt.Sprintf("httptest: failed to start listener on a port for tcp (%v) and tcp6 (%v)", tcpError, tcp6Error)) + return certOut.Bytes(), keyOut.Bytes() } diff --git a/testhelper/httptest/httptest_test.go b/testhelper/httptest/httptest_test.go index f763833f..b329327d 100644 --- a/testhelper/httptest/httptest_test.go +++ b/testhelper/httptest/httptest_test.go @@ -46,7 +46,7 @@ func TestServer(t *testing.T) { func TestUnStartedServer(t *testing.T) { // create a server which is not started - httpUnStartedServer := kithttptest.NewUnStartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + httpUnStartedServer := kithttptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("Hello, world!")) }))