Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

0.0.8 #36

Merged
merged 17 commits into from
Feb 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ func (c *Client) DoRawWithOptions(method, url, uripath string, headers map[strin
return c.do(method, url, uripath, headers, body, redirectstatus, options)
}

func (c *Client) getConn(protocol, host string, options Options) (Conn, error) {
if options.Proxy != "" {
return c.dialer.DialWithProxy(protocol, host, c.Options.Proxy, c.Options.ProxyDialTimeout)
}
if options.Timeout < 0 {
options.Timeout = 0
}
return c.dialer.DialTimeout(protocol, host, options.Timeout)
}

func (c *Client) do(method, url, uripath string, headers map[string][]string, body io.Reader, redirectstatus *RedirectStatus, options Options) (*http.Response, error) {
protocol := "http"
if strings.HasPrefix(strings.ToLower(url), "https://") {
Expand Down Expand Up @@ -137,7 +147,7 @@ func (c *Client) do(method, url, uripath string, headers map[string][]string, bo
protocol = "https"
}

conn, err := c.dialer.Dial(protocol, host)
conn, err := c.getConn(protocol, host, options)
if err != nil {
return nil, err
}
Expand All @@ -148,13 +158,13 @@ func (c *Client) do(method, url, uripath string, headers map[string][]string, bo

// set timeout if any
if options.Timeout > 0 {
conn.SetDeadline(time.Now().Add(options.Timeout))
_ = conn.SetDeadline(time.Now().Add(options.Timeout))
}

if err := conn.WriteRequest(req); err != nil {
return nil, err
}
resp, err := conn.ReadResponse()
resp, err := conn.ReadResponse(options.ForceReadAllBody)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ const readerBuffer = 4096
// HTTP but connection pooling is expected to be handled at a higher layer.
type Client interface {
WriteRequest(*Request) error
ReadResponse() (*Response, error)
ReadResponse(forceReadAll bool) (*Response, error)
}

// NewClient returns a Client implementation which uses rw to communicate.
Expand Down Expand Up @@ -122,7 +122,7 @@ func (c *client) WriteRequest(req *Request) error {
}

// ReadResponse unmarshalls a HTTP response.
func (c *client) ReadResponse() (*Response, error) {
func (c *client) ReadResponse(forceReadAll bool) (*Response, error) {
version, code, msg, err := c.ReadStatusLine()
var headers []Header
if err != nil {
Expand All @@ -148,7 +148,7 @@ func (c *client) ReadResponse() (*Response, error) {
Headers: headers,
Body: c.ReadBody(),
}
if l := resp.ContentLength(); l >= 0 {
if l := resp.ContentLength(); l >= 0 && !forceReadAll {
resp.Body = io.LimitReader(resp.Body, l)
} else if resp.TransferEncoding() == "chunked" {
resp.Body = httputil.NewChunkedReader(resp.Body)
Expand Down
1 change: 0 additions & 1 deletion client/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ type Status struct {
Reason string
}

var invalidStatus Status

func (s Status) String() string { return fmt.Sprintf("%d %s", s.Code, s.Reason) }

Expand Down
50 changes: 50 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package rawhttp

import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/julienschmidt/httprouter"
"github.com/projectdiscovery/stringsutil"
)

func getTestHttpServer(timeout time.Duration) *httptest.Server {
var ts *httptest.Server
router := httprouter.New()
router.GET("/rawhttp", httprouter.Handle(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
time.Sleep(timeout)
}))
ts = httptest.NewServer(router)
return ts
}

// run with go test -timeout 45s -run ^TestDialDefaultTimeout$ github.com/projectdiscovery/rawhttp
func TestDialDefaultTimeout(t *testing.T) {
timeout := 30 * time.Second
ts := getTestHttpServer(45 * time.Second)
defer ts.Close()

startTime := time.Now()
client := NewClient(DefaultOptions)
_, err := client.DoRaw("GET", ts.URL, "/rawhttp", nil, nil)
if !stringsutil.ContainsAny(err.Error(), "i/o timeout") || time.Now().Before(startTime.Add(timeout)) {
t.Error("default timeout error")
}
}

func TestDialWithCustomTimeout(t *testing.T) {
timeout := 5 * time.Second
ts := getTestHttpServer(10 * time.Second)
defer ts.Close()

startTime := time.Now()
client := NewClient(DefaultOptions)
options := DefaultOptions
options.Timeout = timeout
_, err := client.DoRawWithOptions("GET", ts.URL, "/rawhttp", nil, nil, options)
if !stringsutil.ContainsAny(err.Error(), "i/o timeout") || time.Now().Before(startTime.Add(timeout)) {
t.Error("custom timeout error")
}
}
7 changes: 5 additions & 2 deletions clientpipeline/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@ import (
const DefaultMaxConnsPerHost = 512
const DefaultMaxIdleConnDuration = 10 * time.Second
const DefaultMaxIdemponentCallAttempts = 5
const defaultReadBufferSize = 4096
const defaultWriteBufferSize = 4096

type DialFunc func(addr string) (net.Conn, error)
type RetryIfFunc func(request *Request) bool

var errorChPool sync.Pool

var (
ErrNoFreeConns = errors.New("no free connections available to host")
ErrConnectionClosed = errors.New("the server closed connection before returning the first response byte. " +
"Make sure the server returns 'Connection: close' response header before closing the connection")
// ErrGetOnly is returned when server expects only GET requests,
// but some other type of request came (Server.GetOnly option is true).
ErrGetOnly = errors.New("non-GET request received")
)

type timeoutError struct {
Expand Down
81 changes: 0 additions & 81 deletions clientpipeline/http.go

This file was deleted.

4 changes: 2 additions & 2 deletions clientpipeline/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (resp *Response) Read(r *bufio.Reader) error {
}
if key == "" {
// empty header values are valid, rfc 2616 s4.2.
err = errors.New("invalid header")
err = errors.New("invalid header") //nolint
break
}
headers = append(headers, Header{key, value})
Expand Down Expand Up @@ -222,7 +222,7 @@ func (resp *Response) ReadBody(r *bufio.Reader) io.Reader {
l := resp.ContentLength()
if l > 0 {
resp.body = make([]byte, l)
io.ReadFull(r, resp.body)
io.ReadFull(r, resp.body) //nolint

return bytes.NewReader(resp.body)
}
Expand Down
79 changes: 71 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@ package rawhttp

import (
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"strings"
"sync"
"time"

"github.com/projectdiscovery/rawhttp/client"
"github.com/projectdiscovery/rawhttp/proxy"
)

// Dialer can dial a remote HTTP server.
type Dialer interface {
// Dial dials a remote http server returning a Conn.
Dial(protocol, addr string) (Conn, error)
DialWithProxy(protocol, addr, proxyURL string, timeout time.Duration) (Conn, error)
// Dial dials a remote http server with timeout returning a Conn.
DialTimeout(protocol, addr string, timeout time.Duration) (Conn, error)
}

type dialer struct {
Expand All @@ -22,35 +29,91 @@ type dialer struct {
}

func (d *dialer) Dial(protocol, addr string) (Conn, error) {
return d.dialTimeout(protocol, addr, 0)
}

func (d *dialer) DialTimeout(protocol, addr string, timeout time.Duration) (Conn, error) {
return d.dialTimeout(protocol, addr, timeout)
}

func (d *dialer) dialTimeout(protocol, addr string, timeout time.Duration) (Conn, error) {
d.Lock()
if d.conns == nil {
d.conns = make(map[string][]Conn)
}
if c, ok := d.conns[addr]; ok {
if len(c) > 0 {
conn := c[0]
c[0], c = c[len(c)-1], c[:len(c)-1]
c[0] = c[len(c)-1]
d.Unlock()
return conn, nil
}
}
d.Unlock()
c, err := clientDial(protocol, addr)
c, err := clientDial(protocol, addr, timeout)
return &conn{
Client: client.NewClient(c),
Conn: c,
dialer: d,
}, err
}

func clientDial(protocol, addr string) (net.Conn, error) {
// http
if protocol == "http" {
return net.Dial("tcp", addr)
func (d *dialer) DialWithProxy(protocol, addr, proxyURL string, timeout time.Duration) (Conn, error) {
var c net.Conn
u, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("unsupported proxy error: %w", err)
}
switch u.Scheme {
case "http":
c, err = proxy.HTTPDialer(proxyURL, timeout)(addr)
case "socks5", "socks5h":
c, err = proxy.Socks5Dialer(proxyURL, timeout)(addr)
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", proxyURL)
}
if err != nil {
return nil, fmt.Errorf("proxy error: %w", err)
}
if protocol == "https" {
if c, err = TlsHandshake(c, addr); err != nil {
return nil, fmt.Errorf("tls handshake error: %w", err)
}
}
return &conn{
Client: client.NewClient(c),
Conn: c,
dialer: d,
}, err
}

func clientDial(protocol, addr string, timeout time.Duration) (net.Conn, error) {
conn, err := net.DialTimeout("tcp", addr, timeout)
if protocol == "https" {
if conn, err = TlsHandshake(conn, addr); err != nil {
return nil, fmt.Errorf("tls handshake error: %w", err)
}
}
return conn, err
}

// https
return tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
// TlsHandshake tls handshake on a plain connection
func TlsHandshake(conn net.Conn, addr string) (net.Conn, error) {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]

tlsConn := tls.Client(conn, &tls.Config{
InsecureSkipVerify: true,
ServerName: hostname,
})
if err := tlsConn.Handshake(); err != nil {
conn.Close()
return nil, err
}
return tlsConn, nil
}

// Conn is an interface implemented by a connection
Expand Down
8 changes: 5 additions & 3 deletions example/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package main
import (
"fmt"
"net/http"
)

var i int
"github.com/projectdiscovery/gologger"
)

func headers(w http.ResponseWriter, req *http.Request) {
for name, headers := range req.Header {
Expand All @@ -17,5 +17,7 @@ func headers(w http.ResponseWriter, req *http.Request) {

func main() {
http.HandleFunc("/headers", headers)
http.ListenAndServe(":10000", nil)
if err := http.ListenAndServe(":10000", nil); err != nil {
gologger.Fatal().Msgf("Could not listen and serve: %s\n", err)
}
}
Loading