Skip to content

Commit

Permalink
Merge pull request #36 from projectdiscovery/dev
Browse files Browse the repository at this point in the history
0.0.8
  • Loading branch information
ehsandeep authored Feb 28, 2022
2 parents cdf6d96 + f89a181 commit 391d034
Show file tree
Hide file tree
Showing 18 changed files with 300 additions and 120 deletions.
16 changes: 13 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,16 @@ func (c *Client) DoRawWithOptions(method, url, uripath string, headers map[strin
return c.do(method, url, uripath, headers, body, redirectstatus, options)
}

func (c *Client) getConn(protocol, host string, options Options) (Conn, error) {
if options.Proxy != "" {
return c.dialer.DialWithProxy(protocol, host, c.Options.Proxy, c.Options.ProxyDialTimeout)
}
if options.Timeout < 0 {
options.Timeout = 0
}
return c.dialer.DialTimeout(protocol, host, options.Timeout)
}

func (c *Client) do(method, url, uripath string, headers map[string][]string, body io.Reader, redirectstatus *RedirectStatus, options Options) (*http.Response, error) {
protocol := "http"
if strings.HasPrefix(strings.ToLower(url), "https://") {
Expand Down Expand Up @@ -137,7 +147,7 @@ func (c *Client) do(method, url, uripath string, headers map[string][]string, bo
protocol = "https"
}

conn, err := c.dialer.Dial(protocol, host)
conn, err := c.getConn(protocol, host, options)
if err != nil {
return nil, err
}
Expand All @@ -148,13 +158,13 @@ func (c *Client) do(method, url, uripath string, headers map[string][]string, bo

// set timeout if any
if options.Timeout > 0 {
conn.SetDeadline(time.Now().Add(options.Timeout))
_ = conn.SetDeadline(time.Now().Add(options.Timeout))
}

if err := conn.WriteRequest(req); err != nil {
return nil, err
}
resp, err := conn.ReadResponse()
resp, err := conn.ReadResponse(options.ForceReadAllBody)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ const readerBuffer = 4096
// HTTP but connection pooling is expected to be handled at a higher layer.
type Client interface {
WriteRequest(*Request) error
ReadResponse() (*Response, error)
ReadResponse(forceReadAll bool) (*Response, error)
}

// NewClient returns a Client implementation which uses rw to communicate.
Expand Down Expand Up @@ -122,7 +122,7 @@ func (c *client) WriteRequest(req *Request) error {
}

// ReadResponse unmarshalls a HTTP response.
func (c *client) ReadResponse() (*Response, error) {
func (c *client) ReadResponse(forceReadAll bool) (*Response, error) {
version, code, msg, err := c.ReadStatusLine()
var headers []Header
if err != nil {
Expand All @@ -148,7 +148,7 @@ func (c *client) ReadResponse() (*Response, error) {
Headers: headers,
Body: c.ReadBody(),
}
if l := resp.ContentLength(); l >= 0 {
if l := resp.ContentLength(); l >= 0 && !forceReadAll {
resp.Body = io.LimitReader(resp.Body, l)
} else if resp.TransferEncoding() == "chunked" {
resp.Body = httputil.NewChunkedReader(resp.Body)
Expand Down
1 change: 0 additions & 1 deletion client/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ type Status struct {
Reason string
}

var invalidStatus Status

func (s Status) String() string { return fmt.Sprintf("%d %s", s.Code, s.Reason) }

Expand Down
50 changes: 50 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package rawhttp

import (
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/julienschmidt/httprouter"
"github.com/projectdiscovery/stringsutil"
)

func getTestHttpServer(timeout time.Duration) *httptest.Server {
var ts *httptest.Server
router := httprouter.New()
router.GET("/rawhttp", httprouter.Handle(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) {
time.Sleep(timeout)
}))
ts = httptest.NewServer(router)
return ts
}

// run with go test -timeout 45s -run ^TestDialDefaultTimeout$ github.com/projectdiscovery/rawhttp
func TestDialDefaultTimeout(t *testing.T) {
timeout := 30 * time.Second
ts := getTestHttpServer(45 * time.Second)
defer ts.Close()

startTime := time.Now()
client := NewClient(DefaultOptions)
_, err := client.DoRaw("GET", ts.URL, "/rawhttp", nil, nil)
if !stringsutil.ContainsAny(err.Error(), "i/o timeout") || time.Now().Before(startTime.Add(timeout)) {
t.Error("default timeout error")
}
}

func TestDialWithCustomTimeout(t *testing.T) {
timeout := 5 * time.Second
ts := getTestHttpServer(10 * time.Second)
defer ts.Close()

startTime := time.Now()
client := NewClient(DefaultOptions)
options := DefaultOptions
options.Timeout = timeout
_, err := client.DoRawWithOptions("GET", ts.URL, "/rawhttp", nil, nil, options)
if !stringsutil.ContainsAny(err.Error(), "i/o timeout") || time.Now().Before(startTime.Add(timeout)) {
t.Error("custom timeout error")
}
}
7 changes: 5 additions & 2 deletions clientpipeline/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@ import (
const DefaultMaxConnsPerHost = 512
const DefaultMaxIdleConnDuration = 10 * time.Second
const DefaultMaxIdemponentCallAttempts = 5
const defaultReadBufferSize = 4096
const defaultWriteBufferSize = 4096

type DialFunc func(addr string) (net.Conn, error)
type RetryIfFunc func(request *Request) bool

var errorChPool sync.Pool

var (
ErrNoFreeConns = errors.New("no free connections available to host")
ErrConnectionClosed = errors.New("the server closed connection before returning the first response byte. " +
"Make sure the server returns 'Connection: close' response header before closing the connection")
// ErrGetOnly is returned when server expects only GET requests,
// but some other type of request came (Server.GetOnly option is true).
ErrGetOnly = errors.New("non-GET request received")
)

type timeoutError struct {
Expand Down
81 changes: 0 additions & 81 deletions clientpipeline/http.go

This file was deleted.

4 changes: 2 additions & 2 deletions clientpipeline/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (resp *Response) Read(r *bufio.Reader) error {
}
if key == "" {
// empty header values are valid, rfc 2616 s4.2.
err = errors.New("invalid header")
err = errors.New("invalid header") //nolint
break
}
headers = append(headers, Header{key, value})
Expand Down Expand Up @@ -222,7 +222,7 @@ func (resp *Response) ReadBody(r *bufio.Reader) io.Reader {
l := resp.ContentLength()
if l > 0 {
resp.body = make([]byte, l)
io.ReadFull(r, resp.body)
io.ReadFull(r, resp.body) //nolint

return bytes.NewReader(resp.body)
}
Expand Down
79 changes: 71 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,25 @@ package rawhttp

import (
"crypto/tls"
"fmt"
"io"
"net"
"net/url"
"strings"
"sync"
"time"

"github.com/projectdiscovery/rawhttp/client"
"github.com/projectdiscovery/rawhttp/proxy"
)

// Dialer can dial a remote HTTP server.
type Dialer interface {
// Dial dials a remote http server returning a Conn.
Dial(protocol, addr string) (Conn, error)
DialWithProxy(protocol, addr, proxyURL string, timeout time.Duration) (Conn, error)
// Dial dials a remote http server with timeout returning a Conn.
DialTimeout(protocol, addr string, timeout time.Duration) (Conn, error)
}

type dialer struct {
Expand All @@ -22,35 +29,91 @@ type dialer struct {
}

func (d *dialer) Dial(protocol, addr string) (Conn, error) {
return d.dialTimeout(protocol, addr, 0)
}

func (d *dialer) DialTimeout(protocol, addr string, timeout time.Duration) (Conn, error) {
return d.dialTimeout(protocol, addr, timeout)
}

func (d *dialer) dialTimeout(protocol, addr string, timeout time.Duration) (Conn, error) {
d.Lock()
if d.conns == nil {
d.conns = make(map[string][]Conn)
}
if c, ok := d.conns[addr]; ok {
if len(c) > 0 {
conn := c[0]
c[0], c = c[len(c)-1], c[:len(c)-1]
c[0] = c[len(c)-1]
d.Unlock()
return conn, nil
}
}
d.Unlock()
c, err := clientDial(protocol, addr)
c, err := clientDial(protocol, addr, timeout)
return &conn{
Client: client.NewClient(c),
Conn: c,
dialer: d,
}, err
}

func clientDial(protocol, addr string) (net.Conn, error) {
// http
if protocol == "http" {
return net.Dial("tcp", addr)
func (d *dialer) DialWithProxy(protocol, addr, proxyURL string, timeout time.Duration) (Conn, error) {
var c net.Conn
u, err := url.Parse(proxyURL)
if err != nil {
return nil, fmt.Errorf("unsupported proxy error: %w", err)
}
switch u.Scheme {
case "http":
c, err = proxy.HTTPDialer(proxyURL, timeout)(addr)
case "socks5", "socks5h":
c, err = proxy.Socks5Dialer(proxyURL, timeout)(addr)
default:
return nil, fmt.Errorf("unsupported proxy protocol: %s", proxyURL)
}
if err != nil {
return nil, fmt.Errorf("proxy error: %w", err)
}
if protocol == "https" {
if c, err = TlsHandshake(c, addr); err != nil {
return nil, fmt.Errorf("tls handshake error: %w", err)
}
}
return &conn{
Client: client.NewClient(c),
Conn: c,
dialer: d,
}, err
}

func clientDial(protocol, addr string, timeout time.Duration) (net.Conn, error) {
conn, err := net.DialTimeout("tcp", addr, timeout)
if protocol == "https" {
if conn, err = TlsHandshake(conn, addr); err != nil {
return nil, fmt.Errorf("tls handshake error: %w", err)
}
}
return conn, err
}

// https
return tls.Dial("tcp", addr, &tls.Config{InsecureSkipVerify: true})
// TlsHandshake tls handshake on a plain connection
func TlsHandshake(conn net.Conn, addr string) (net.Conn, error) {
colonPos := strings.LastIndex(addr, ":")
if colonPos == -1 {
colonPos = len(addr)
}
hostname := addr[:colonPos]

tlsConn := tls.Client(conn, &tls.Config{
InsecureSkipVerify: true,
ServerName: hostname,
})
if err := tlsConn.Handshake(); err != nil {
conn.Close()
return nil, err
}
return tlsConn, nil
}

// Conn is an interface implemented by a connection
Expand Down
8 changes: 5 additions & 3 deletions example/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ package main
import (
"fmt"
"net/http"
)

var i int
"github.com/projectdiscovery/gologger"
)

func headers(w http.ResponseWriter, req *http.Request) {
for name, headers := range req.Header {
Expand All @@ -17,5 +17,7 @@ func headers(w http.ResponseWriter, req *http.Request) {

func main() {
http.HandleFunc("/headers", headers)
http.ListenAndServe(":10000", nil)
if err := http.ListenAndServe(":10000", nil); err != nil {
gologger.Fatal().Msgf("Could not listen and serve: %s\n", err)
}
}
Loading

0 comments on commit 391d034

Please sign in to comment.