Skip to content

Commit

Permalink
Merge pull request #301 from projectdiscovery/bugfix-semi-permanent
Browse files Browse the repository at this point in the history
Handling runtime ip block
  • Loading branch information
Mzack9999 authored Jun 11, 2024
2 parents 026588d + d597162 commit 427bffb
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 93 deletions.
36 changes: 25 additions & 11 deletions fastdialer/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/json"
"net"
"strings"
"sync/atomic"

"github.com/Mzack9999/gcache"
gounit "github.com/docker/go-units"
Expand Down Expand Up @@ -53,14 +54,15 @@ type Dialer struct {
// memory typed cache
mDnsCache gcache.Cache[string, *retryabledns.DNSData]
// memory/disk untyped ([]byte) cache
hmDnsCache *hybrid.HybridMap
hostsFileData *hybrid.HybridMap
dialerHistory *hybrid.HybridMap
dialerTLSData *hybrid.HybridMap
dialer *net.Dialer
proxyDialer *proxy.Dialer
networkpolicy *networkpolicy.NetworkPolicy
dialCache gcache.Cache[string, *utils.DialWrap]
hmDnsCache *hybrid.HybridMap
hostsFileData *hybrid.HybridMap
dialerHistory *hybrid.HybridMap
dialerTLSData *hybrid.HybridMap
dialer *net.Dialer
proxyDialer *proxy.Dialer
networkpolicy *networkpolicy.NetworkPolicy
dialCache gcache.Cache[string, *utils.DialWrap]
dialTimeoutErrors gcache.Cache[string, *atomic.Uint32]
}

// NewDialer instance
Expand Down Expand Up @@ -158,7 +160,7 @@ func NewDialer(options Options) (*Dialer, error) {
if err != nil {
return nil, err
}
return &Dialer{
d := &Dialer{
dnsclient: dnsclient,
mDnsCache: dnsCache,
hmDnsCache: hmDnsCache,
Expand All @@ -170,7 +172,13 @@ func NewDialer(options Options) (*Dialer, error) {
options: &options,
networkpolicy: np,
dialCache: gcache.New[string, *utils.DialWrap](MaxDialCacheSize).Build(),
}, nil
}

if options.MaxTemporaryErrors > 0 && options.MaxTemporaryToPermanentDuration > 0 {
d.dialTimeoutErrors = gcache.New[string, *atomic.Uint32](MaxDialCacheSize).Expiration(options.MaxTemporaryToPermanentDuration).Build()
}

return d, nil
}

// Dial function compatible with net/http
Expand Down Expand Up @@ -273,7 +281,13 @@ func (d *Dialer) Close() {
if d.options.WithTLSData {
d.dialerTLSData.Close()
}
// donot close hosts file as it is meant to be shared
if d.dialCache != nil {
d.dialCache.Purge()
}
if d.dialTimeoutErrors != nil {
d.dialTimeoutErrors.Purge()
}
// do not close hosts file as it is meant to be shared
}

// GetDialedIP returns the ip dialed by the HTTP client
Expand Down
79 changes: 47 additions & 32 deletions fastdialer/dialer_private.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net"
"os"
"strings"
"sync/atomic"
"time"

"github.com/projectdiscovery/fastdialer/fastdialer/ja3/impersonate"
Expand Down Expand Up @@ -46,6 +47,20 @@ type dialOptions struct {
hostname string
}

// connHash returns a hash of the connection
func (d *dialOptions) connHash() string {
return fmt.Sprintf("%s-%s", d.network, d.address)
}

// logAddress returns the address to be logged in case of error
func (d *dialOptions) logAddress() string {
logAddress := d.hostname
if logAddress == "" {
logAddress = d.ips[0]
}
return net.JoinHostPort(logAddress, d.port)
}

func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, err error) {
// add global timeout to context
ctx, cancel := context.WithTimeoutCause(ctx, d.options.DialerTimeout, ErrDialTimeout)
Expand All @@ -54,9 +69,8 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
var hostname, port, fixedIP string
var IPS []string
// check if this is present in cache
dw, _ := d.dialCache.GetIFPresent(connHash(opts.network, opts.address))
if dw == nil {

dw, err := d.dialCache.GetIFPresent(opts.connHash())
if err != nil || dw == nil {
if strings.HasPrefix(opts.address, "[") {
closeBracketIndex := strings.Index(opts.address, "]")
if closeBracketIndex == -1 {
Expand Down Expand Up @@ -145,21 +159,21 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
// only cache it if below conditions are met
// 1. it is already not present
// 2. it is a domain and not ip
// 3. it has more than 1 valid ip
// 3. it has at least 1 valid ip
// 4. proxy dialer is not set

dw, err = utils.NewDialWrap(d.dialer, IPS, opts.network, opts.address, opts.port)
if err != nil {
return nil, errkit.Wrap(err, "could not create dialwrap")
}
if err = d.dialCache.Set(connHash(opts.network, opts.address), dw); err != nil {
if err = d.dialCache.Set(opts.connHash(), dw); err != nil {
return nil, errkit.Wrap(err, "could not set dialwrap")
}
}
if dw != nil {
finalDialer = dw
// when using dw ip , network , port etc are preset
// so get any one of them to avoid breaking furthur logic
// so get any one of them to avoid breaking further logic
ip, port := dw.Address()
opts.ips = []string{ip}
opts.port = port
Expand Down Expand Up @@ -208,13 +222,6 @@ func (d *Dialer) dial(ctx context.Context, opts *dialOptions) (conn net.Conn, er
func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (conn net.Conn, err error) {
hostPort := net.JoinHostPort(opts.ips[0], opts.port)

// logAddress is the address that will be logged in case of error
logAddress := opts.hostname
if logAddress == "" {
logAddress = opts.ips[0]
}
logAddress += ":" + opts.port

if opts.shouldUseTLS {
tlsconfigCopy := opts.tlsconfig.Clone()

Expand All @@ -231,7 +238,7 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
if opts.impersonateStrategy == impersonate.None {
l4Conn, err := l4.DialContext(ctx, opts.network, hostPort)
if err != nil {
return nil, handleDialError(err, logAddress)
return nil, d.handleDialError(err, opts)
}
TlsConn := tls.Client(l4Conn, tlsconfigCopy)
if err := TlsConn.HandshakeContext(ctx); err != nil {
Expand All @@ -241,7 +248,7 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
} else {
nativeConn, err := l4.DialContext(ctx, opts.network, hostPort)
if err != nil {
return nil, handleDialError(err, logAddress)
return nil, d.handleDialError(err, opts)
}
// clone existing tls config
uTLSConfig := &utls.Config{
Expand Down Expand Up @@ -279,7 +286,7 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
}
l4Conn, err := l4.DialContext(ctx, opts.network, hostPort)
if err != nil {
return nil, handleDialError(err, logAddress)
return nil, d.handleDialError(err, opts)
}
ztlsConn := ztls.Client(l4Conn, ztlsconfigCopy)
_, err = ctxutil.ExecFuncWithTwoReturns(ctx, func() (bool, error) {
Expand Down Expand Up @@ -314,10 +321,10 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
case conn = <-connectionCh:
case err = <-errCh:
}
err = handleDialError(err, logAddress)
err = d.handleDialError(err, opts)
} else {
conn, err = l4.DialContext(ctx, opts.network, hostPort)
err = handleDialError(err, logAddress)
err = d.handleDialError(err, opts)
}
}
// fallback to ztls in case of handshake error with chrome ciphers
Expand All @@ -342,7 +349,7 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
ztlsconfigCopy.CipherSuites = ztls.ChromeCiphers
l4Conn, err := l4.DialContext(ctx, opts.network, hostPort)
if err != nil {
return nil, handleDialError(err, logAddress)
return nil, d.handleDialError(err, opts)
}
ztlsConn := ztls.Client(l4Conn, ztlsconfigCopy)
_, err = ctxutil.ExecFuncWithTwoReturns(ctx, func() (bool, error) {
Expand All @@ -357,28 +364,36 @@ func (d *Dialer) dialIPS(ctx context.Context, l4 l4dialer, opts *dialOptions) (c
return
}

// connHash returns a hash of the connection
func connHash(network string, address string) string {
return fmt.Sprintf("%s-%s", network, address)
}

// handleDialError is a helper function to handle dial errors
// it also adds address attribute to the error
func handleDialError(err error, address string) error {
func (d *Dialer) handleDialError(err error, opts *dialOptions) error {
if err == nil {
return nil
}
errx := errkit.FromError(err)
errx = errx.SetAttr(slog.Any("address", address))
// if error kind is not set, if it is i/o timeout, set it to temporary
if errx.Kind() == nil {
errx = errx.SetAttr(slog.Any("address", opts.logAddress()))

if errx.Kind() == errkit.ErrKindUnknown {
if errx.Cause() != nil && strings.Contains(errx.Cause().Error(), "i/o timeout") {
// TODO: this is a tough call, i/o timeout happens in both cases
// it could be either temporary or permanent internally i/o timeout
// is actually a context.DeadlineExceeded error but std lib has decided to keep legacy/original error
// mark timeout errors as temporary
errx = errx.SetKind(errkit.ErrKindNetworkTemporary)

if d.dialTimeoutErrors != nil {
count, err := d.dialTimeoutErrors.GetIFPresent(opts.connHash())
if err != nil {
count = &atomic.Uint32{}
count.Store(1)
} else {
count.Add(1)
}
_ = d.dialTimeoutErrors.Set(opts.connHash(), count)

// update them to permament if they happened multiple times within 30s
if count.Load() > uint32(d.options.MaxTemporaryErrors) {
errx = errx.ResetKind().SetKind(errkit.ErrKindNetworkPermanent)
}
}
}
}
// TODO: parse and mark permanent or temporary errors
return errx
}
39 changes: 0 additions & 39 deletions fastdialer/dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,42 +52,3 @@ func testDialer(t *testing.T, options Options) {
// cleanup
fd.Close()
}

// // nolint
// func testDialerIpv6(t *testing.T, options Options) {
// // disk based
// fd, err := NewDialer(options)
// if err != nil {
// t.Fatalf("couldn't create fastdialer instance: %s", err)
// }

// // valid resolution + cache
// ctx := context.Background()
// conn, err := fd.Dial(ctx, "tcp", "ipv6.google.com:80")
// if err != nil || conn == nil {
// t.Fatalf("couldn't connect to target: %s", err)
// }
// conn.Close()
// // retrieve cached data
// data, err := fd.GetDNSData("ipv6.google.com")
// if err != nil || data == nil {
// t.Fatalf("couldn't retrieve dns data: %s", err)
// }
// if len(data.AAAA) == 0 {
// t.Error("no AAAA results found")
// }

// // test address pinning
// // this test passes, but will fail if the hard-coded ipv6 address changes
// // need to find a better way to test this
// /*
// conn, err = fd.Dial(ctx, "tcp", "ipv6.google.com:80:[2607:f8b0:4006:807::200e]")
// if err != nil || conn == nil {
// t.Errorf("couldn't connect to target: %s", err)
// }
// conn.Close()
// */

// // cleanup
// fd.Close()
// }
22 changes: 14 additions & 8 deletions fastdialer/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,22 @@ type Options struct {
OnDialCallback func(hostname, IP string)
DisableZtlsFallback bool
WithNetworkPolicyOptions *networkpolicy.Options
Logger *log.Logger // optional logger to log errors(like hostfile init error)
// optional logger to log errors(like hostfile init error)
Logger *log.Logger
// optional max temporary errors to mark as permanent
MaxTemporaryErrors int
MaxTemporaryToPermanentDuration time.Duration
}

// DefaultOptions of the cache
var DefaultOptions = Options{
BaseResolvers: DefaultResolvers,
MaxRetries: 5,
HostsFile: true,
ResolversFile: true,
CacheType: Disk,
DialerTimeout: 10 * time.Second,
DialerKeepAlive: 10 * time.Second,
BaseResolvers: DefaultResolvers,
MaxRetries: 5,
HostsFile: true,
ResolversFile: true,
CacheType: Disk,
DialerTimeout: 10 * time.Second,
DialerKeepAlive: 10 * time.Second,
MaxTemporaryErrors: 30,
MaxTemporaryToPermanentDuration: time.Minute,
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ require (
github.com/projectdiscovery/hmap v0.0.46
github.com/projectdiscovery/networkpolicy v0.0.8
github.com/projectdiscovery/retryabledns v1.0.63
github.com/projectdiscovery/utils v0.1.3
github.com/projectdiscovery/utils v0.1.4-0.20240611113448-0e2f2d33fe1c
github.com/refraction-networking/utls v1.5.4
github.com/stretchr/testify v1.9.0
github.com/tarunKoyalwar/goleak v0.0.0-20240429141123-0efa90dbdcf9
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ github.com/projectdiscovery/networkpolicy v0.0.8 h1:XvfBaBwSDNTesSfNQP9VLk3HX9I7
github.com/projectdiscovery/networkpolicy v0.0.8/go.mod h1:xnjNqhemxUPxU+UD5Jgsc3+K8IVmcqT1SJeo6UzMtkI=
github.com/projectdiscovery/retryabledns v1.0.63 h1:Ijt47ybwf+iIwul606NlPKPN+ouhOYPeA9sLHVgqLG4=
github.com/projectdiscovery/retryabledns v1.0.63/go.mod h1:lTs48OYJnMFuuBzT+3z3PrZ58K0OUBgP7Y4o3ttBwb0=
github.com/projectdiscovery/utils v0.1.3 h1:yhHkrbYZA1eOO8e+fPDUvRMS5aUIalyM3Nab7rK4tpg=
github.com/projectdiscovery/utils v0.1.3/go.mod h1:gny8RbNYXE55IoamF6thRDQ8tcJEw+r0FOGAvncz/oQ=
github.com/projectdiscovery/utils v0.1.4-0.20240611113448-0e2f2d33fe1c h1:0I/iRtu5nPYle1v8/R33pCLOrH5bziP5Bi0eZURxTQY=
github.com/projectdiscovery/utils v0.1.4-0.20240611113448-0e2f2d33fe1c/go.mod h1:mXs6OOeG9l/dVchjB2PGvQO3+wuMiE14Y/kmHeKogoM=
github.com/quic-go/quic-go v0.42.0 h1:uSfdap0eveIl8KXnipv9K7nlwZ5IqLlYOpJ58u5utpM=
github.com/quic-go/quic-go v0.42.0/go.mod h1:132kz4kL3F9vxhW3CtQJLDVwcFe5wdWeJXXijhsO57M=
github.com/refraction-networking/utls v1.5.4 h1:9k6EO2b8TaOGsQ7Pl7p9w6PUhx18/ZCeT0WNTZ7Uw4o=
Expand Down

0 comments on commit 427bffb

Please sign in to comment.