From b2e302fc908eed8d885433c867a09b2a15f8941f Mon Sep 17 00:00:00 2001 From: Shiming Zhang Date: Tue, 31 Jan 2023 23:24:23 +0800 Subject: [PATCH] Fix compatibility --- compatibility_read_deadline.go | 17 +++++++++++-- go.mod | 4 +-- go.sum | 8 +++--- proxy.go | 45 ++++++++++++++++++++++++++++------ tunnel.go | 1 - 5 files changed, 59 insertions(+), 16 deletions(-) diff --git a/compatibility_read_deadline.go b/compatibility_read_deadline.go index 7fb3d9c..0d25494 100644 --- a/compatibility_read_deadline.go +++ b/compatibility_read_deadline.go @@ -26,7 +26,20 @@ func (w listenerCompatibilityReadDeadline) Accept() (net.Conn, error) { if err != nil { return nil, err } - return connCompatibilityReadDeadline{c}, nil + return NewConnCompatibilityReadDeadline(c), nil +} + +// NewConnCompatibilityReadDeadline this is a wrapper used to be compatible with +// the net.Conn after wrapping it so that it can be hijacked properly. +// there is no effect if the content is not manipulated. +func NewConnCompatibilityReadDeadline(conn net.Conn) net.Conn { + if conn == nil { + return nil + } + if conn, ok := conn.(connCompatibilityReadDeadline); ok { + return conn + } + return connCompatibilityReadDeadline{conn} } type connCompatibilityReadDeadline struct { @@ -35,7 +48,7 @@ type connCompatibilityReadDeadline struct { func (d connCompatibilityReadDeadline) SetReadDeadline(t time.Time) error { if aLongTimeAgo == t { - t = time.Now().Add(time.Second) + t = time.Now().Add(1 * time.Second) } return d.Conn.SetReadDeadline(t) } diff --git a/go.mod b/go.mod index 575c775..3ed60ae 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,6 @@ module github.com/wzshiming/httpproxy go 1.18 -require golang.org/x/net v0.2.0 +require golang.org/x/net v0.5.0 -require golang.org/x/text v0.4.0 // indirect +require golang.org/x/text v0.6.0 // indirect diff --git a/go.sum b/go.sum index 7f4ee92..e6f7a4e 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ -golang.org/x/net v0.2.0 h1:sZfSu1wtKLGlWI4ZZayP0ck9Y73K1ynO6gqzTdBVdPU= -golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= +golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= diff --git a/proxy.go b/proxy.go index edb05f6..e18d88c 100644 --- a/proxy.go +++ b/proxy.go @@ -1,6 +1,7 @@ package httpproxy import ( + "bufio" "context" "fmt" "io" @@ -96,16 +97,13 @@ func (p *ProxyHandler) proxyConnect(w http.ResponseWriter, r *http.Request) { http.Error(w, e, http.StatusInternalServerError) return } + defer targetConn.Close() - if flusher, ok := w.(http.Flusher); ok { - flusher.Flush() - } else { - w.WriteHeader(http.StatusOK) - } + w.WriteHeader(http.StatusOK) - clientConn, _, err := hijacker.Hijack() + conn, rw, err := hijacker.Hijack() if err != nil { - e := err.Error() + e := fmt.Sprintf("hijack failed: %v", err) if p.Logger != nil { p.Logger.Println(e) } @@ -113,6 +111,8 @@ func (p *ProxyHandler) proxyConnect(w http.ResponseWriter, r *http.Request) { return } + clientConn := newBufConn(conn, rw) + var buf1, buf2 []byte if p.BytesPool != nil { buf1 = p.BytesPool.Get() @@ -151,3 +151,34 @@ func (p *ProxyHandler) proxyDial(ctx context.Context, network, address string) ( } return proxyDial(ctx, network, address) } + +func newBufConn(conn net.Conn, rw *bufio.ReadWriter) net.Conn { + rw.Flush() + if rw.Reader.Buffered() == 0 { + // If there's no buffered data to be read, + // we can just discard the bufio.ReadWriter. + return conn + } + return &bufConn{conn, rw.Reader} +} + +// bufConn wraps a net.Conn, but reads drain the bufio.Reader first. +type bufConn struct { + net.Conn + *bufio.Reader +} + +func (c *bufConn) Read(p []byte) (int, error) { + if c.Reader == nil { + return c.Conn.Read(p) + } + n := c.Reader.Buffered() + if n == 0 { + c.Reader = nil + return c.Conn.Read(p) + } + if n < len(p) { + p = p[:n] + } + return c.Reader.Read(p) +} diff --git a/tunnel.go b/tunnel.go index a0281cf..7f3eb0a 100644 --- a/tunnel.go +++ b/tunnel.go @@ -5,7 +5,6 @@ import ( "io" ) - // tunnel create tunnels for two io.ReadWriteCloser func tunnel(ctx context.Context, c1, c2 io.ReadWriteCloser, buf1, buf2 []byte) error { ctx, cancel := context.WithCancel(ctx)