diff --git a/client.go b/client.go index bb70166..487ba73 100644 --- a/client.go +++ b/client.go @@ -196,17 +196,17 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) { } var conn net.Conn - begin := time.Now() hostName := hostname.GetHostName(d.u) - // conn, err := net.DialTimeout("tcp", d.u.Host /* TODO 加端号*/, d.dialTimeout) - dialFunc := net.Dial + dialFunc := net.DialTimeout if d.dialFunc != nil { dialInterface, err := d.dialFunc() if err != nil { return nil, err } - dialFunc = dialInterface.Dial + dialFunc = func(network, address string, timeout time.Duration) (net.Conn, error) { + return dialInterface.Dial(network, address) + } } if d.proxyFunc != nil { @@ -214,16 +214,14 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) { if err != nil { return nil, err } - dialFunc = newhttpProxy(proxyURL, dialFunc).Dial + dialFunc = newhttpProxy(proxyURL, dialFunc).DialTimeout } - conn, err = dialFunc("tcp", hostName) + conn, err = dialFunc("tcp", hostName, d.dialTimeout) if err != nil { return nil, err } - dialDuration := time.Since(begin) - conn = d.tlsConn(conn) defer func() { if err != nil && conn != nil { @@ -232,18 +230,7 @@ func (d *DialOption) Dial() (wsCon *Conn, err error) { } }() - if to := d.dialTimeout - dialDuration; to > 0 { - if err = conn.SetDeadline(time.Now().Add(to)); err != nil { - return - } - } - - defer func() { - if err == nil { - err = conn.SetDeadline(time.Time{}) - } - }() - + err = conn.SetDeadline(time.Time{}) if err = req.Write(conn); err != nil { return } diff --git a/common_options_test.go b/common_options_test.go index 6f2ec96..6436b9b 100644 --- a/common_options_test.go +++ b/common_options_test.go @@ -2227,13 +2227,14 @@ func Test_CommonOption(t *testing.T) { t.Run("22.3.WithClientReadMaxMessage", func(t *testing.T) { var tsort testServerOptionReadTimeout - upgrade := NewUpgrade(WithServerCallback(&tsort), WithServerReadTimeout(time.Millisecond*60)) + upgrade := NewUpgrade() tsort.err = make(chan error, 1) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := upgrade.Upgrade(w, r) if err != nil { t.Error(err) } + time.Sleep(time.Second / 100) err = c.WriteMessage(Binary, bytes.Repeat([]byte("1"), 1025)) if err != nil { t.Error(err) @@ -2245,12 +2246,57 @@ func Test_CommonOption(t *testing.T) { defer ts.Close() url := strings.ReplaceAll(ts.URL, "http", "ws") - con, err := Dial(url, WithClientBufioParseMode(), WithClientReadMaxMessage(1<<10), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { + con, err := Dial(url, WithClientCallback(&tsort), WithClientBufioParseMode(), WithClientReadMaxMessage(1<<10)) + if err != nil { + t.Error(err) + return + } + defer con.Close() + go func() { + _ = con.ReadLoop() + }() + + select { + case d := <-tsort.err: + if d == nil { + t.Errorf("got:nil, need:error\n") + } + case <-time.After(100 * time.Hour): + t.Errorf(" Test_ServerOption:WithServerReadMaxMessage timeout\n") + } + if atomic.LoadInt32(&tsort.run) != 1 { + t.Error("not run server:method fail") + } + }) + t.Run("22.4.WithClientReadMaxMessage-ParseWindows", func(t *testing.T) { + var tsort testServerOptionReadTimeout + + upgrade := NewUpgrade() + tsort.err = make(chan error, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := upgrade.Upgrade(w, r) + if err != nil { + t.Error(err) + } + time.Sleep(time.Second / 100) + err = c.WriteMessage(Binary, bytes.Repeat([]byte("1"), 1025)) + if err != nil { + t.Error(err) + return + } + c.StartReadLoop() })) + + defer ts.Close() + + url := strings.ReplaceAll(ts.URL, "http", "ws") + con, err := Dial(url, WithClientCallback(&tsort), WithClientReadMaxMessage(1<<10)) if err != nil { t.Error(err) + return } defer con.Close() + con.StartReadLoop() select { case d := <-tsort.err: diff --git a/config.go b/config.go index 82ec1a5..2383476 100644 --- a/config.go +++ b/config.go @@ -27,10 +27,16 @@ import ( var ErrDialFuncAndProxyFunc = errors.New("dialFunc and proxyFunc can't be set at the same time") +// 握手 type Dialer interface { Dial(network, addr string) (c net.Conn, err error) } +// 带超时时间的握手 +type DialerTimeout interface { + DialTimeout(network, addr string, timeout time.Duration) (c net.Conn, err error) +} + // Config的配置,有两个种用法 // 一种是声明一个全局的配置,后面不停使用。 // 另外一种是局部声明一个配置,然后使用WithXXX函数设置配置 diff --git a/conn.go b/conn.go index e15e906..0e2cceb 100644 --- a/conn.go +++ b/conn.go @@ -208,10 +208,11 @@ func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, bufioPa } } else { r := io.Reader(c.br) + var lr io.Reader if c.readMaxMessage > 0 { - r = limitreader.NewLimitReader(c.br, c.readMaxMessage) + lr = limitreader.NewLimitReader(c.br, c.readMaxMessage) } - f, err = frame.ReadFrameFromReaderV2(r, headArray, bufioPayload) + f, err = frame.ReadFrameFromReaderV3(r, lr, headArray, bufioPayload) } if err != nil { c.writeAndMaybeOnClose(err) diff --git a/conn_test.go b/conn_test.go index 2fa01b0..a4a7561 100644 --- a/conn_test.go +++ b/conn_test.go @@ -608,7 +608,7 @@ func TestFragmentFrame(t *testing.T) { select { case <-data: atomic.AddInt32(&run, 1) - case <-time.After(500 * time.Hour): + case <-time.After(500 * time.Millisecond): } if atomic.LoadInt32(&run) != 1 { diff --git a/go.mod b/go.mod index 1269a2a..d7a5365 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/antlabs/quickws go 1.21 require ( - github.com/antlabs/wsutil v0.1.10 + github.com/antlabs/wsutil v0.1.11 golang.org/x/net v0.23.0 ) diff --git a/go.sum b/go.sum index 776990d..b7d7e6d 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -github.com/antlabs/wsutil v0.1.10 h1:86p67dG8/iiQ+yZrHVl73OPHGnXfXopFSU0w84fLOdE= -github.com/antlabs/wsutil v0.1.10/go.mod h1:Pk7xYOw3o5iEB6ukiOu+2uJMLYeMVVjJLazFD3okI2A= +github.com/antlabs/wsutil v0.1.11 h1:bIVZ3Hxdq5ByZKu5OXL/cMtanEw6YlxdtUDiySI77Q0= +github.com/antlabs/wsutil v0.1.11/go.mod h1:Pk7xYOw3o5iEB6ukiOu+2uJMLYeMVVjJLazFD3okI2A= github.com/klauspost/compress v1.17.8 h1:YcnTYrq7MikUT7k0Yb5eceMmALQPYBW/Xltxn0NAMnU= github.com/klauspost/compress v1.17.8/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= diff --git a/proxy.go b/proxy.go index 9aaa8a8..ee8c74d 100644 --- a/proxy.go +++ b/proxy.go @@ -20,31 +20,33 @@ import ( "net" "net/http" "net/url" + "time" "github.com/antlabs/wsutil/hostname" ) type ( - dialFunc func(network, addr string) (c net.Conn, err error) + dialFunc func(network, addr string, timeout time.Duration) (c net.Conn, err error) httpProxy struct { - proxyAddr *url.URL - dial func(network, addr string) (c net.Conn, err error) + proxyAddr *url.URL + dialTimeout func(network, addr string, timeout time.Duration) (c net.Conn, err error) + timeout time.Duration } ) -var _ Dialer = (*httpProxy)(nil) +var _ DialerTimeout = (*httpProxy)(nil) func newhttpProxy(u *url.URL, dial dialFunc) *httpProxy { - return &httpProxy{proxyAddr: u, dial: dial} + return &httpProxy{proxyAddr: u, dialTimeout: dial} } -func (h *httpProxy) Dial(network, addr string) (c net.Conn, err error) { +func (h *httpProxy) DialTimeout(network, addr string, timeout time.Duration) (c net.Conn, err error) { if h.proxyAddr == nil { - return h.dial(network, addr) + return h.dialTimeout(network, addr, h.timeout) } hostName := hostname.GetHostName(h.proxyAddr) - c, err = h.dial(network, hostName) + c, err = h.dialTimeout(network, hostName, h.timeout) if err != nil { return nil, err } diff --git a/proxy_test.go b/proxy_test.go index 610c759..1895d6e 100644 --- a/proxy_test.go +++ b/proxy_test.go @@ -22,6 +22,7 @@ import ( "net/url" "strings" "testing" + "time" ) type testServer struct { @@ -133,7 +134,7 @@ func Test_Proxy(t *testing.T) { func Test_httpProxy_Dial(t *testing.T) { type fields struct { proxyAddr *url.URL - dial func(network, addr string) (c net.Conn, err error) + dial func(network, addr string, timeout time.Duration) (c net.Conn, err error) } type args struct { network string @@ -146,12 +147,12 @@ func Test_httpProxy_Dial(t *testing.T) { wantC net.Conn wantErr bool }{ - // TODO: Add test cases. + // 0 { name: "No proxy address", fields: fields{ proxyAddr: nil, - dial: func(network, addr string) (c net.Conn, err error) { + dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) { // Simulate successful dialing return &net.TCPConn{}, errors.New("fail") }, @@ -163,11 +164,12 @@ func Test_httpProxy_Dial(t *testing.T) { wantC: &net.TCPConn{}, wantErr: true, }, + // 1 { name: "Proxy address", fields: fields{ proxyAddr: &url.URL{Host: "1.2.3:8080", User: url.UserPassword("user", "password")}, - dial: func(network, addr string) (c net.Conn, err error) { + dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) { // Simulate successful dialing return &net.TCPConn{}, errors.New("fail") }, @@ -179,11 +181,12 @@ func Test_httpProxy_Dial(t *testing.T) { wantC: &net.TCPConn{}, wantErr: true, }, + // 2 { name: "Proxy address", fields: fields{ proxyAddr: &url.URL{Host: "1.2.3:8080", User: url.UserPassword("user", "password")}, - dial: func(network, addr string) (c net.Conn, err error) { + dial: func(network, addr string, timeout time.Duration) (c net.Conn, err error) { // Simulate successful dialing return &net.TCPConn{}, nil }, @@ -193,17 +196,16 @@ func Test_httpProxy_Dial(t *testing.T) { addr: "a.b.c:80", }, wantC: &net.TCPConn{}, - wantErr: true, + wantErr: false, }, } for i, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := &httpProxy{ - proxyAddr: tt.fields.proxyAddr, - dial: tt.fields.dial, + proxyAddr: tt.fields.proxyAddr, + dialTimeout: tt.fields.dial, } - _, err := h.Dial(tt.args.network, tt.args.addr) - // gotC, err := h.Dial(tt.args.network, tt.args.addr) + _, err := h.dialTimeout(tt.args.network, tt.args.addr, 0) if (err != nil) != tt.wantErr { t.Errorf("index:%d, httpProxy.Dial() error = %v, wantErr %v", i, err, tt.wantErr) return