diff --git a/pkg/transport/keepalive_listener_test.go b/pkg/transport/keepalive_listener_test.go index f8c062cf72f8..425f53368b54 100644 --- a/pkg/transport/keepalive_listener_test.go +++ b/pkg/transport/keepalive_listener_test.go @@ -18,7 +18,6 @@ import ( "crypto/tls" "net" "net/http" - "os" "testing" ) @@ -50,12 +49,12 @@ func TestNewKeepAliveListener(t *testing.T) { } // tls - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { t.Fatalf("unable to create tmpfile: %v", err) } - defer os.Remove(tmp) - tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp} + defer del() + tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile} tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) tlscfg, err := tlsInfo.ServerConfig() if err != nil { diff --git a/pkg/transport/listener_test.go b/pkg/transport/listener_test.go index d5ff317892fd..b39b3d53760f 100644 --- a/pkg/transport/listener_test.go +++ b/pkg/transport/listener_test.go @@ -24,18 +24,16 @@ import ( "time" ) -func createTempFile(b []byte) (string, error) { - f, err := ioutil.TempFile("", "etcd-test-tls-") - if err != nil { - return "", err +func createSelfCert() (*TLSInfo, func(), error) { + d, terr := ioutil.TempDir("", "etcd-test-tls-") + if terr != nil { + return nil, nil, terr } - defer f.Close() - - if _, err = f.Write(b); err != nil { - return "", err + info, err := SelfCert(d, []string{"127.0.0.1"}) + if err != nil { + return nil, nil, err } - - return f.Name(), nil + return &info, func() { os.RemoveAll(d) }, nil } func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) { @@ -47,14 +45,12 @@ func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBloc // TestNewListenerTLSInfo tests that NewListener with valid TLSInfo returns // a TLS listener that accepts TLS connections. func TestNewListenerTLSInfo(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsInfo, del, err := createSelfCert() if err != nil { - t.Fatalf("unable to create tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) - tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp} - tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil) - testNewListenerTLSInfoAccept(t, tlsInfo) + defer del() + testNewListenerTLSInfoAccept(t, *tlsInfo) } func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) { @@ -62,13 +58,17 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) { if err != nil { t.Fatalf("unexpected serverConfig error: %v", err) } + tlscfg.InsecureSkipVerify = true ln, err := NewListener("127.0.0.1:0", "https", tlscfg) if err != nil { t.Fatalf("unexpected NewListener error: %v", err) } defer ln.Close() - go http.Get("https://" + ln.Addr().String()) + tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + cli := &http.Client{Transport: tr} + go cli.Get("https://" + ln.Addr().String()) + conn, err := ln.Accept() if err != nil { t.Fatalf("unexpected Accept error: %v", err) @@ -87,25 +87,25 @@ func TestNewListenerTLSEmptyInfo(t *testing.T) { } func TestNewTransportTLSInfo(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { - t.Fatalf("Unable to prepare tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) + defer del() tests := []TLSInfo{ {}, { - CertFile: tmp, - KeyFile: tmp, + CertFile: tlsinfo.CertFile, + KeyFile: tlsinfo.KeyFile, }, { - CertFile: tmp, - KeyFile: tmp, - CAFile: tmp, + CertFile: tlsinfo.CertFile, + KeyFile: tlsinfo.KeyFile, + CAFile: tlsinfo.CAFile, }, { - CAFile: tmp, + CAFile: tlsinfo.CAFile, }, } @@ -159,17 +159,17 @@ func TestTLSInfoEmpty(t *testing.T) { } func TestTLSInfoMissingFields(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { - t.Fatalf("Unable to prepare tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) + defer del() tests := []TLSInfo{ - {CertFile: tmp}, - {KeyFile: tmp}, - {CertFile: tmp, CAFile: tmp}, - {KeyFile: tmp, CAFile: tmp}, + {CertFile: tlsinfo.CertFile}, + {KeyFile: tlsinfo.KeyFile}, + {CertFile: tlsinfo.CertFile, CAFile: tlsinfo.CAFile}, + {KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CAFile}, } for i, info := range tests { @@ -184,30 +184,29 @@ func TestTLSInfoMissingFields(t *testing.T) { } func TestTLSInfoParseFuncError(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { - t.Fatalf("Unable to prepare tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) + defer del() - info := TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp} - info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake")) + tlsinfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake")) - if _, err = info.ServerConfig(); err == nil { + if _, err = tlsinfo.ServerConfig(); err == nil { t.Errorf("expected non-nil error from ServerConfig()") } - if _, err = info.ClientConfig(); err == nil { + if _, err = tlsinfo.ClientConfig(); err == nil { t.Errorf("expected non-nil error from ClientConfig()") } } func TestTLSInfoConfigFuncs(t *testing.T) { - tmp, err := createTempFile([]byte("XXX")) + tlsinfo, del, err := createSelfCert() if err != nil { - t.Fatalf("Unable to prepare tmpfile: %v", err) + t.Fatalf("unable to create cert: %v", err) } - defer os.Remove(tmp) + defer del() tests := []struct { info TLSInfo @@ -215,13 +214,13 @@ func TestTLSInfoConfigFuncs(t *testing.T) { wantCAs bool }{ { - info: TLSInfo{CertFile: tmp, KeyFile: tmp}, + info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile}, clientAuth: tls.NoClientCert, wantCAs: false, }, { - info: TLSInfo{CertFile: tmp, KeyFile: tmp, CAFile: tmp}, + info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, CAFile: tlsinfo.CertFile}, clientAuth: tls.RequireAndVerifyClientCert, wantCAs: true, }, diff --git a/pkg/transport/listener_tls.go b/pkg/transport/listener_tls.go index fc19a221eed2..8dc5c9cd7325 100644 --- a/pkg/transport/listener_tls.go +++ b/pkg/transport/listener_tls.go @@ -107,7 +107,6 @@ func (l *tlsListener) acceptLoop() { } } } - select { case l.connc <- tlsConn: conn = nil