diff --git a/compatibility_read_deadline.go b/compatibility_read_deadline.go new file mode 100644 index 0000000..7fb3d9c --- /dev/null +++ b/compatibility_read_deadline.go @@ -0,0 +1,41 @@ +package httpproxy + +import ( + "net" + "time" +) + +// aLongTimeAgo is a non-zero time, far in the past, used for +// immediate cancellation of network operations. +// copies from http +var aLongTimeAgo = time.Unix(1, 0) + +// NewListenerCompatibilityReadDeadline this is a wrapper used to be compatible with +// the contents of ServerConn after wrapping it so that it can be hijacked properly. +// there is no effect if the content is not manipulated. +func NewListenerCompatibilityReadDeadline(listener net.Listener) net.Listener { + return listenerCompatibilityReadDeadline{listener} +} + +type listenerCompatibilityReadDeadline struct { + net.Listener +} + +func (w listenerCompatibilityReadDeadline) Accept() (net.Conn, error) { + c, err := w.Listener.Accept() + if err != nil { + return nil, err + } + return connCompatibilityReadDeadline{c}, nil +} + +type connCompatibilityReadDeadline struct { + net.Conn +} + +func (d connCompatibilityReadDeadline) SetReadDeadline(t time.Time) error { + if aLongTimeAgo == t { + t = time.Now().Add(time.Second) + } + return d.Conn.SetReadDeadline(t) +} diff --git a/compatibility_read_deadline_test.go b/compatibility_read_deadline_test.go new file mode 100644 index 0000000..fa43591 --- /dev/null +++ b/compatibility_read_deadline_test.go @@ -0,0 +1,97 @@ +package httpproxy + +import ( + "context" + "encoding/hex" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "testing" +) + +func TestNewListenerCompatibilityReadDeadline(t *testing.T) { + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "check", r.RequestURI) + })) + + listener, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + listener = newHexListener(listener) + listener = NewListenerCompatibilityReadDeadline(listener) + + s, err := NewSimpleServer("http://u:p@:0") + if err != nil { + t.Fatal(err) + } + + s.Listener = listener + s.Start(context.Background()) + defer s.Close() + + dial, err := NewDialer(s.ProxyURL()) + if err != nil { + t.Fatal(err) + } + dial.ProxyDial = func(ctx context.Context, network, address string) (net.Conn, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + conn = newHexConn(conn) + return conn, nil + } + cli := testServer.Client() + cli.Transport = &http.Transport{ + DialContext: dial.DialContext, + } + + resp, err := cli.Get(testServer.URL) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() +} + +func newHexListener(listener net.Listener) net.Listener { + return hexListener{ + Listener: listener, + } +} + +type hexListener struct { + net.Listener +} + +func (h hexListener) Accept() (net.Conn, error) { + conn, err := h.Listener.Accept() + if err != nil { + return nil, err + } + return newHexConn(conn), nil +} + +func newHexConn(conn net.Conn) net.Conn { + return hexConn{ + Conn: conn, + r: hex.NewDecoder(conn), + w: hex.NewEncoder(conn), + } +} + +type hexConn struct { + net.Conn + r io.Reader + w io.Writer +} + +func (h hexConn) Read(p []byte) (n int, err error) { + return h.r.Read(p) +} + +func (h hexConn) Write(p []byte) (n int, err error) { + return h.w.Write(p) +} diff --git a/simple_server.go b/simple_server.go index 5ee47cc..018f859 100644 --- a/simple_server.go +++ b/simple_server.go @@ -52,25 +52,29 @@ func NewSimpleServer(addr string) (*SimpleServer, error) { // Run the server func (s *SimpleServer) Run(ctx context.Context) error { var listenConfig net.ListenConfig - listener, err := listenConfig.Listen(ctx, s.Network, s.Address) - if err != nil { - return err + if s.Listener == nil { + listener, err := listenConfig.Listen(ctx, s.Network, s.Address) + if err != nil { + return err + } + s.Listener = NewListenerCompatibilityReadDeadline(listener) } - s.Listener = listener - s.Address = listener.Addr().String() - return s.Server.Serve(listener) + s.Address = s.Listener.Addr().String() + return s.Server.Serve(s.Listener) } // Start the server func (s *SimpleServer) Start(ctx context.Context) error { var listenConfig net.ListenConfig - listener, err := listenConfig.Listen(ctx, s.Network, s.Address) - if err != nil { - return err + if s.Listener == nil { + listener, err := listenConfig.Listen(ctx, s.Network, s.Address) + if err != nil { + return err + } + s.Listener = NewListenerCompatibilityReadDeadline(listener) } - s.Listener = listener - s.Address = listener.Addr().String() - go s.Server.Serve(listener) + s.Address = s.Listener.Addr().String() + go s.Server.Serve(s.Listener) return nil }