Skip to content

Commit

Permalink
feat: make region-client map concurrency safe using custom structure …
Browse files Browse the repository at this point in the history
…with a mutex
  • Loading branch information
lavafroth committed Sep 11, 2023
1 parent bfd0fce commit 7b7a856
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 51 deletions.
15 changes: 9 additions & 6 deletions provider/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@ package provider
import (
"context"
"errors"
"net/http"

"github.com/aws/aws-sdk-go-v2/aws"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
log "github.com/sirupsen/logrus"
"net/http"
)

type providerAWS struct {
existsClient *s3.Client
clients map[string]*s3.Client
clients *clientmap.ClientMap
}

func (a *providerAWS) BucketExists(b *bucket.Bucket) (*bucket.Bucket, error) {
Expand Down Expand Up @@ -85,7 +87,8 @@ func NewProviderAWS() (*providerAWS, error) {
if usErr != nil {
return nil, usErr
}
pa.clients = map[string]*s3.Client{"us-east-1": usEastClient}
pa.clients = clientmap.New()
pa.clients.Set("us-east-1", usEastClient)
return pa, nil
}

Expand Down Expand Up @@ -142,8 +145,8 @@ func (a *providerAWS) newClient(region string) (*s3.Client, error) {

// TODO: This method is copied from providerLinode
func (a *providerAWS) getRegionClient(region string) (*s3.Client, error) {
c, ok := a.clients[region]
if ok {
c := a.clients.Get(region)
if c != nil {
return c, nil
}

Expand All @@ -152,6 +155,6 @@ func (a *providerAWS) getRegionClient(region string) (*s3.Client, error) {
if err != nil {
return nil, err
}
a.clients[region] = c // TODO: Make sure this is thread-safe
a.clients.Set(region, c)
return c, nil
}
54 changes: 54 additions & 0 deletions provider/clientmap/clientmap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package clientmap

import (
"github.com/aws/aws-sdk-go-v2/service/s3"
"sync"
)

type ClientMap struct {
sync.Mutex
inner map[string]*s3.Client
}

func New() *ClientMap {
return &ClientMap{
Mutex: sync.Mutex{},
inner: make(map[string]*s3.Client),
}
}

func WithCapacity(cap int) *ClientMap {
return &ClientMap{
Mutex: sync.Mutex{},
inner: make(map[string]*s3.Client, cap),
}
}

func (m *ClientMap) Get(key string) *s3.Client {
m.Lock()
defer m.Unlock()
if v, ok := m.inner[key]; ok {
return v
}
return nil
}

func (m *ClientMap) Set(key string, value *s3.Client) {
m.Lock()
m.inner[key] = value
m.Unlock()
}

func (m *ClientMap) Len() int {
m.Lock()
defer m.Unlock()
return len(m.inner)
}

func (m *ClientMap) Each(fn func(region string, client *s3.Client)) {
m.Lock()
for region, client := range m.inner {
fn(region, client)
}
m.Unlock()
}
18 changes: 8 additions & 10 deletions provider/custom.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ package provider
import (
"errors"
"fmt"
"strings"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"strings"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

type CustomProvider struct {
regions []string
clients map[string]*s3.Client
clients *clientmap.ClientMap
insecure bool
addressStyle int
endpointFormat string
Expand Down Expand Up @@ -66,11 +68,7 @@ func (cp CustomProvider) Enumerate(b *bucket.Bucket) error {
}

func (cp *CustomProvider) getRegionClient(region string) *s3.Client {
c, ok := cp.clients[region]
if ok {
return c
}
return nil
return cp.clients.Get(region)
}

/*
Expand Down Expand Up @@ -98,15 +96,15 @@ func NewCustomProvider(addressStyle string, insecure bool, regions []string, end
return cp, nil
}

func (cp *CustomProvider) newClients() (map[string]*s3.Client, error) {
clients := make(map[string]*s3.Client, len(cp.regions))
func (cp *CustomProvider) newClients() (*clientmap.ClientMap, error) {
clients := clientmap.WithCapacity(len(cp.regions))
for _, r := range cp.regions {
regionUrl := strings.Replace(cp.endpointFormat, "$REGION", r, -1)
client, err := newNonAWSClient(cp, regionUrl)
if err != nil {
return nil, err
}
clients[r] = client
clients.Set(r, client)
}

return clients, nil
Expand Down
16 changes: 7 additions & 9 deletions provider/digitalocean.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package provider
import (
"errors"
"fmt"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

type providerDO struct {
regions []string
clients map[string]*s3.Client
clients *clientmap.ClientMap
}

func (pdo providerDO) Insecure() bool {
Expand Down Expand Up @@ -66,25 +68,21 @@ func (pdo *providerDO) Regions() []string {
return urls
}

func (pdo *providerDO) newClients() (map[string]*s3.Client, error) {
clients := make(map[string]*s3.Client, len(pdo.regions))
func (pdo *providerDO) newClients() (*clientmap.ClientMap, error) {
clients := clientmap.WithCapacity(len(pdo.regions))
for _, r := range pdo.Regions() {
client, err := newNonAWSClient(pdo, r)
if err != nil {
return nil, err
}
clients[r] = client
clients.Set(r, client)
}

return clients, nil
}

func (pdo *providerDO) getRegionClient(region string) *s3.Client {
c, ok := pdo.clients[region]
if ok {
return c
}
return nil
return pdo.clients.Get(region)
}

func NewProviderDO() (*providerDO, error) {
Expand Down
16 changes: 7 additions & 9 deletions provider/dreamhost.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package provider
import (
"errors"
"fmt"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

type ProviderDreamhost struct {
regions []string
clients map[string]*s3.Client
clients *clientmap.ClientMap
}

func (p ProviderDreamhost) Insecure() bool {
Expand Down Expand Up @@ -46,11 +48,7 @@ func (p ProviderDreamhost) Scan(bucket *bucket.Bucket, doDestructiveChecks bool)
}

func (p ProviderDreamhost) getRegionClient(region string) *s3.Client {
c, ok := p.clients[region]
if ok {
return c
}
return nil
return p.clients.Get(region)
}

func (p ProviderDreamhost) Enumerate(b *bucket.Bucket) error {
Expand All @@ -74,14 +72,14 @@ func (p ProviderDreamhost) Regions() []string {
return urls
}

func (p *ProviderDreamhost) newClients() (map[string]*s3.Client, error) {
clients := make(map[string]*s3.Client, len(p.regions))
func (p *ProviderDreamhost) newClients() (*clientmap.ClientMap, error) {
clients := clientmap.WithCapacity(len(p.regions))
for _, r := range p.Regions() {
client, err := newNonAWSClient(p, r)
if err != nil {
return nil, err
}
clients[r] = client
clients.Set(r, client)
}

return clients, nil
Expand Down
6 changes: 5 additions & 1 deletion provider/gcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package provider

import (
"errors"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

// GCP like AWS, has a "universal" endpoint, but unlike AWS GCP does not require you to follow a redirect to the
Expand All @@ -30,7 +32,9 @@ func (g GCP) BucketExists(b *bucket.Bucket) (*bucket.Bucket, error) {
if !bucket.IsValidS3BucketName(b.Name) {
return nil, errors.New("invalid bucket name")
}
exists, region, err := bucketExists(map[string]*s3.Client{"default": g.client}, b)
clients := clientmap.New()
clients.Set("default", g.client)
exists, region, err := bucketExists(clients, b)
if err != nil {
return b, err
}
Expand Down
16 changes: 7 additions & 9 deletions provider/linode.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package provider
import (
"errors"
"fmt"

"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/provider/clientmap"
)

type providerLinode struct {
regions []string
clients map[string]*s3.Client
clients *clientmap.ClientMap
}

func NewProviderLinode() (*providerLinode, error) {
Expand All @@ -25,11 +27,7 @@ func NewProviderLinode() (*providerLinode, error) {
}

func (pl *providerLinode) getRegionClient(region string) *s3.Client {
c, ok := pl.clients[region]
if ok {
return c
}
return nil
return pl.clients.Get(region)
}

func (pl *providerLinode) BucketExists(b *bucket.Bucket) (*bucket.Bucket, error) {
Expand Down Expand Up @@ -61,14 +59,14 @@ func (pl *providerLinode) Enumerate(b *bucket.Bucket) error {
return nil
}

func (pl *providerLinode) newClients() (map[string]*s3.Client, error) {
clients := make(map[string]*s3.Client, len(pl.regions))
func (pl *providerLinode) newClients() (*clientmap.ClientMap, error) {
clients := clientmap.WithCapacity(len(pl.regions))
for _, r := range pl.Regions() {
client, err := newNonAWSClient(pl, r)
if err != nil {
return nil, err
}
clients[r] = client
clients.Set(r, client)
}

return clients, nil
Expand Down
16 changes: 9 additions & 7 deletions provider/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"crypto/tls"
"errors"
"fmt"
"net/http"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
"github.com/aws/aws-sdk-go-v2/config"
Expand All @@ -13,9 +16,8 @@ import (
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/sa7mon/s3scanner/bucket"
"github.com/sa7mon/s3scanner/permission"
"github.com/sa7mon/s3scanner/provider/clientmap"
log "github.com/sirupsen/logrus"
"net/http"
"time"
)

const (
Expand Down Expand Up @@ -195,13 +197,13 @@ func checkPermissions(client *s3.Client, b *bucket.Bucket, doDestructiveChecks b
return nil
}

func bucketExists(clients map[string]*s3.Client, b *bucket.Bucket) (bool, string, error) {
func bucketExists(clients *clientmap.ClientMap, b *bucket.Bucket) (bool, string, error) {
// TODO: Should this return a client or a region name? If region name, we'll need GetClient(region)
// TODO: Add region priority - order in which to check. maps are not ordered
results := make(chan bucketCheckResult, len(clients))
results := make(chan bucketCheckResult, clients.Len())
e := make(chan error, 1)

for region, client := range clients {
clients.Each(func(region string, client *s3.Client) {
go func(bucketName string, client *s3.Client, region string) {
logFields := log.Fields{
"bucket_name": b.Name,
Expand Down Expand Up @@ -234,9 +236,9 @@ func bucketExists(clients map[string]*s3.Client, b *bucket.Bucket) (bool, string
e <- err
}
}(b.Name, client, region)
}
})

for range clients {
for i := 0; i < clients.Len(); i++ {
select {
case err := <-e:
return false, "", err
Expand Down

0 comments on commit 7b7a856

Please sign in to comment.