From 743bfea8a32b6269a98b4be342a15b61b5b38279 Mon Sep 17 00:00:00 2001 From: Himadri Bhattacharjee Date: Sun, 24 Sep 2023 19:23:40 +0530 Subject: [PATCH] feat: implement unthrottled concurrency using task queue - create worker goroutines specified through cli - goroutines steal incoming tasks from a channel and execute them - workers consume tasks instead of plain domain name strings - a task consists of a domain and a provider This approach spawns n green threads instead of n * len(providers). Should prevent resource usage from blowing up and help scaling. --- cmd/gau/main.go | 83 +++++++++++------------- pkg/providers/commoncrawl/commoncrawl.go | 3 +- pkg/providers/otx/otx.go | 6 +- pkg/providers/providers.go | 2 +- pkg/providers/urlscan/urlscan.go | 17 ++--- runner/flags/flags.go | 9 +-- runner/runner.go | 72 ++++++++++---------- 7 files changed, 83 insertions(+), 109 deletions(-) diff --git a/cmd/gau/main.go b/cmd/gau/main.go index f1182bc..5397bfc 100644 --- a/cmd/gau/main.go +++ b/cmd/gau/main.go @@ -2,27 +2,20 @@ package main import ( "bufio" + "io" + "os" + "sync" + "github.com/lc/gau/v2/pkg/output" "github.com/lc/gau/v2/runner" "github.com/lc/gau/v2/runner/flags" log "github.com/sirupsen/logrus" - "io" - "os" - "sync" ) func main() { - flag := flags.New() - cfg, err := flag.ReadInConfig() + cfg, err := flags.New().ReadInConfig() if err != nil { - if cfg.Verbose { - log.Warnf("error reading config: %v", err) - } - } - - pMap := make(runner.ProvidersMap) - for _, provider := range cfg.Providers { - pMap[provider] = cfg.Filters + log.Warnf("error reading config: %v", err) } config, err := cfg.ProviderConfig() @@ -30,9 +23,9 @@ func main() { log.Fatal(err) } - gau := &runner.Runner{} + gau := new(runner.Runner) - if err = gau.Init(config, pMap); err != nil { + if err = gau.Init(config, cfg.Providers, cfg.Filters); err != nil { log.Warn(err) } @@ -40,52 +33,52 @@ func main() { var out io.Writer // Handle results in background - if config.Output == "" { - out = os.Stdout - } else { - ofp, err := os.OpenFile(config.Output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if config.Output != "" { + out, err := os.OpenFile(config.Output, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatalf("Could not open output file: %v\n", err) } - defer ofp.Close() - out = ofp + defer out.Close() + } else { + out = os.Stdout } - writeWg := &sync.WaitGroup{} + writeWg := new(sync.WaitGroup) writeWg.Add(1) - if config.JSON { - go func() { - defer writeWg.Done() + go func(JSON bool) { + defer writeWg.Done() + if JSON { output.WriteURLsJSON(out, results, config.Blacklist, config.RemoveParameters) - }() - } else { - go func() { - defer writeWg.Done() - if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters); err != nil { - log.Fatalf("error writing results: %v\n", err) - } - }() - } + } else if err = output.WriteURLs(out, results, config.Blacklist, config.RemoveParameters); err != nil { + log.Fatalf("error writing results: %v\n", err) + } + }(config.JSON) - domains := make(chan string) - gau.Start(domains, results) + workChan := make(chan runner.Work) + gau.Start(workChan, results) - if len(flags.Args()) > 0 { - for _, domain := range flags.Args() { - domains <- domain + domains := flags.Args() + if len(domains) > 0 { + for _, provider := range gau.Providers { + for _, domain := range domains { + workChan <- runner.NewWork(domain, provider) + } } } else { sc := bufio.NewScanner(os.Stdin) - for sc.Scan() { - domains <- sc.Text() - } + for _, provider := range gau.Providers { + for sc.Scan() { + workChan <- runner.NewWork(sc.Text(), provider) - if err := sc.Err(); err != nil { - log.Fatal(err) + if err := sc.Err(); err != nil { + log.Fatal(err) + } + } } + } - close(domains) + close(workChan) // wait for providers to fetch URLS gau.Wait() diff --git a/pkg/providers/commoncrawl/commoncrawl.go b/pkg/providers/commoncrawl/commoncrawl.go index 80b52c6..791ac16 100644 --- a/pkg/providers/commoncrawl/commoncrawl.go +++ b/pkg/providers/commoncrawl/commoncrawl.go @@ -64,11 +64,10 @@ func (c *Client) Fetch(ctx context.Context, domain string, results chan string) return nil } -paginate: for page := uint(0); page < p.Pages; page++ { select { case <-ctx.Done(): - break paginate + return nil default: logrus.WithFields(logrus.Fields{"provider": Name, "page": page}).Infof("fetching %s", domain) apiURL := c.formatURL(domain, page) diff --git a/pkg/providers/otx/otx.go b/pkg/providers/otx/otx.go index ce5b4ad..33dbbd5 100644 --- a/pkg/providers/otx/otx.go +++ b/pkg/providers/otx/otx.go @@ -46,11 +46,10 @@ func (c *Client) Name() string { } func (c *Client) Fetch(ctx context.Context, domain string, results chan string) error { -paginate: for page := uint(1); ; page++ { select { case <-ctx.Done(): - break paginate + return nil default: logrus.WithFields(logrus.Fields{"provider": Name, "page": page - 1}).Infof("fetching %s", domain) apiURL := c.formatURL(domain, page) @@ -68,11 +67,10 @@ paginate: } if !result.HasNext { - break paginate + return nil } } } - return nil } func (c *Client) formatURL(domain string, page uint) string { diff --git a/pkg/providers/providers.go b/pkg/providers/providers.go index d8c1124..055470e 100644 --- a/pkg/providers/providers.go +++ b/pkg/providers/providers.go @@ -6,7 +6,7 @@ import ( "github.com/valyala/fasthttp" ) -const Version = `2.1.2` +const Version = `2.2.0` // Provider is a generic interface for all archive fetchers type Provider interface { diff --git a/pkg/providers/urlscan/urlscan.go b/pkg/providers/urlscan/urlscan.go index 7391926..bd10204 100644 --- a/pkg/providers/urlscan/urlscan.go +++ b/pkg/providers/urlscan/urlscan.go @@ -15,8 +15,6 @@ const ( Name = "urlscan" ) -var _ providers.Provider = (*Client)(nil) - type Client struct { config *providers.Config } @@ -41,11 +39,10 @@ func (c *Client) Fetch(ctx context.Context, domain string, results chan string) header.Value = c.config.URLScan.APIKey } -paginate: for page := uint(0); ; page++ { select { case <-ctx.Done(): - break paginate + return nil default: logrus.WithFields(logrus.Fields{"provider": Name, "page": page}).Infof("fetching %s", domain) apiURL := c.formatURL(domain, searchAfter) @@ -62,7 +59,7 @@ paginate: // rate limited if result.Status == 429 { logrus.WithField("provider", "urlscan").Warnf("urlscan responded with 429, probably being rate limited") - break paginate + return nil } total := len(result.Results) @@ -73,20 +70,18 @@ paginate: if i == total-1 { sortParam := parseSort(res.Sort) - if sortParam != "" { - searchAfter = sortParam - } else { - break paginate + if sortParam == "" { + return nil } + searchAfter = sortParam } } if !result.HasMore { - break paginate + return nil } } } - return nil } func (c *Client) formatURL(domain string, after string) string { diff --git a/runner/flags/flags.go b/runner/flags/flags.go index ef4edea..5b44b8b 100644 --- a/runner/flags/flags.go +++ b/runner/flags/flags.go @@ -234,13 +234,8 @@ func (o *Options) getFlagValues(c *Config) { c.RemoveParameters = fp } - if json { - c.JSON = true - } - - if verbose { - c.Verbose = verbose - } + c.JSON = json + c.Verbose = verbose // get filter flags mc := o.viper.GetStringSlice("mc") diff --git a/runner/runner.go b/runner/runner.go index 38b6f31..a5efa79 100644 --- a/runner/runner.go +++ b/runner/runner.go @@ -13,36 +13,33 @@ import ( ) type Runner struct { - providers []providers.Provider - wg sync.WaitGroup + sync.WaitGroup - config *providers.Config + Providers []providers.Provider + threads uint ctx context.Context cancelFunc context.CancelFunc } -type ProvidersMap map[string]providers.Filters - // Init initializes the runner -func (r *Runner) Init(c *providers.Config, providerMap ProvidersMap) error { - r.config = c +func (r *Runner) Init(c *providers.Config, providers []string, filters providers.Filters) error { + r.threads = c.Threads r.ctx, r.cancelFunc = context.WithCancel(context.Background()) - for name, filters := range providerMap { + for _, name := range providers { switch name { case "urlscan": - r.providers = append(r.providers, urlscan.New(c)) + r.Providers = append(r.Providers, urlscan.New(c)) case "otx": - o := otx.New(c) - r.providers = append(r.providers, o) + r.Providers = append(r.Providers, otx.New(c)) case "wayback": - r.providers = append(r.providers, wayback.New(c, filters)) + r.Providers = append(r.Providers, wayback.New(c, filters)) case "commoncrawl": cc, err := commoncrawl.New(c, filters) if err != nil { return fmt.Errorf("error instantiating commoncrawl: %v\n", err) } - r.providers = append(r.providers, cc) + r.Providers = append(r.Providers, cc) } } @@ -50,44 +47,41 @@ func (r *Runner) Init(c *providers.Config, providerMap ProvidersMap) error { } // Starts starts the worker -func (r *Runner) Start(domains chan string, results chan string) { - for i := uint(0); i < r.config.Threads; i++ { - r.wg.Add(1) +func (r *Runner) Start(workChan chan Work, results chan string) { + for i := uint(0); i < r.threads; i++ { + r.Add(1) go func() { - defer r.wg.Done() - r.worker(r.ctx, domains, results) + defer r.Done() + r.worker(r.ctx, workChan, results) }() } } -// Wait waits for the providers to finish fetching -func (r *Runner) Wait() { - r.wg.Wait() +type Work struct { + domain string + provider providers.Provider +} + +func NewWork(domain string, provider providers.Provider) Work { + return Work{domain, provider} +} + +func (w *Work) Do(ctx context.Context, results chan string) error { + return w.provider.Fetch(ctx, w.domain, results) } // worker checks to see if the context is finished and executes the fetching process for each provider -func (r *Runner) worker(ctx context.Context, domains chan string, results chan string) { -work: +func (r *Runner) worker(ctx context.Context, workChan chan Work, results chan string) { for { select { case <-ctx.Done(): - break work - case domain, ok := <-domains: - if ok { - var wg sync.WaitGroup - for _, p := range r.providers { - wg.Add(1) - go func(p providers.Provider) { - defer wg.Done() - if err := p.Fetch(ctx, domain, results); err != nil { - logrus.WithField("provider", p.Name()).Warnf("%s - %v", domain, err) - } - }(p) - } - wg.Wait() - } + return + case work, ok := <-workChan: if !ok { - break work + return + } + if err := work.Do(ctx, results); err != nil { + logrus.WithField("provider", work.provider.Name()).Warnf("%s - %v", work.domain, err) } } }