Skip to content

Commit

Permalink
Compatibility with ReadDeadline
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Nov 18, 2022
1 parent 4e7640f commit 7389afb
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 12 deletions.
41 changes: 41 additions & 0 deletions compatibility_read_deadline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package httpproxy

import (
"net"
"time"
)

// aLongTimeAgo is a non-zero time, far in the past, used for
// immediate cancellation of network operations.
// copies from http
var aLongTimeAgo = time.Unix(1, 0)

// NewListenerCompatibilityReadDeadline this is a wrapper used to be compatible with
// the contents of ServerConn after wrapping it so that it can be hijacked properly.
// there is no effect if the content is not manipulated.
func NewListenerCompatibilityReadDeadline(listener net.Listener) net.Listener {
return listenerCompatibilityReadDeadline{listener}
}

type listenerCompatibilityReadDeadline struct {
net.Listener
}

func (w listenerCompatibilityReadDeadline) Accept() (net.Conn, error) {
c, err := w.Listener.Accept()
if err != nil {
return nil, err
}
return connCompatibilityReadDeadline{c}, nil
}

type connCompatibilityReadDeadline struct {
net.Conn
}

func (d connCompatibilityReadDeadline) SetReadDeadline(t time.Time) error {
if aLongTimeAgo == t {
t = time.Now().Add(time.Second)
}
return d.Conn.SetReadDeadline(t)
}
97 changes: 97 additions & 0 deletions compatibility_read_deadline_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package httpproxy

import (
"context"
"encoding/hex"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"testing"
)

func TestNewListenerCompatibilityReadDeadline(t *testing.T) {
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "check", r.RequestURI)
}))

listener, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
listener = newHexListener(listener)
listener = NewListenerCompatibilityReadDeadline(listener)

s, err := NewSimpleServer("http://u:p@:0")
if err != nil {
t.Fatal(err)
}

s.Listener = listener
s.Start(context.Background())
defer s.Close()

dial, err := NewDialer(s.ProxyURL())
if err != nil {
t.Fatal(err)
}
dial.ProxyDial = func(ctx context.Context, network, address string) (net.Conn, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
conn = newHexConn(conn)
return conn, nil
}
cli := testServer.Client()
cli.Transport = &http.Transport{
DialContext: dial.DialContext,
}

resp, err := cli.Get(testServer.URL)
if err != nil {
t.Fatal(err)
}
resp.Body.Close()
}

func newHexListener(listener net.Listener) net.Listener {
return hexListener{
Listener: listener,
}
}

type hexListener struct {
net.Listener
}

func (h hexListener) Accept() (net.Conn, error) {
conn, err := h.Listener.Accept()
if err != nil {
return nil, err
}
return newHexConn(conn), nil
}

func newHexConn(conn net.Conn) net.Conn {
return hexConn{
Conn: conn,
r: hex.NewDecoder(conn),
w: hex.NewEncoder(conn),
}
}

type hexConn struct {
net.Conn
r io.Reader
w io.Writer
}

func (h hexConn) Read(p []byte) (n int, err error) {
return h.r.Read(p)
}

func (h hexConn) Write(p []byte) (n int, err error) {
return h.w.Write(p)
}
28 changes: 16 additions & 12 deletions simple_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,25 +52,29 @@ func NewSimpleServer(addr string) (*SimpleServer, error) {
// Run the server
func (s *SimpleServer) Run(ctx context.Context) error {
var listenConfig net.ListenConfig
listener, err := listenConfig.Listen(ctx, s.Network, s.Address)
if err != nil {
return err
if s.Listener == nil {
listener, err := listenConfig.Listen(ctx, s.Network, s.Address)
if err != nil {
return err
}
s.Listener = NewListenerCompatibilityReadDeadline(listener)
}
s.Listener = listener
s.Address = listener.Addr().String()
return s.Server.Serve(listener)
s.Address = s.Listener.Addr().String()
return s.Server.Serve(s.Listener)
}

// Start the server
func (s *SimpleServer) Start(ctx context.Context) error {
var listenConfig net.ListenConfig
listener, err := listenConfig.Listen(ctx, s.Network, s.Address)
if err != nil {
return err
if s.Listener == nil {
listener, err := listenConfig.Listen(ctx, s.Network, s.Address)
if err != nil {
return err
}
s.Listener = NewListenerCompatibilityReadDeadline(listener)
}
s.Listener = listener
s.Address = listener.Addr().String()
go s.Server.Serve(listener)
s.Address = s.Listener.Addr().String()
go s.Server.Serve(s.Listener)
return nil
}

Expand Down

0 comments on commit 7389afb

Please sign in to comment.