Skip to content

Commit

Permalink
Support ETags for Azure and Google Cloud [#176, #177] (#183)
Browse files Browse the repository at this point in the history
* NewRangeReaderEtag for gocloud buckets does provider-specific operations for set/get etag and error.
* use generation ID as string etag for Google Cloud.
  • Loading branch information
bdon authored Sep 9, 2024
1 parent f9ff34e commit ddd3524
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 17 deletions.
92 changes: 75 additions & 17 deletions pmtiles/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@ import (
"os"
"path"
"path/filepath"
"strconv"
"strings"

"cloud.google.com/go/storage"
"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob/container"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/cespare/xxhash/v2"
"gocloud.dev/blob"
"google.golang.org/api/googleapi"
)

// Bucket is an abstration over a gocloud or plain HTTP bucket.
Expand Down Expand Up @@ -211,35 +218,86 @@ func (ba BucketAdapter) NewRangeReader(ctx context.Context, key string, offset,
return body, err
}

func etagToGeneration(etag string) int64 {
i, _ := strconv.ParseInt(etag, 10, 64)
return i
}

func generationToEtag(generation int64) string {
return strconv.FormatInt(generation, 10)
}

func setProviderEtag(asFunc func(interface{}) bool, etag string) {
var awsV1Req *s3.GetObjectInput
var azblobReq *azblob.DownloadStreamOptions
var gcsHandle **storage.ObjectHandle
if asFunc(&awsV1Req) {
awsV1Req.IfMatch = aws.String(etag)
} else if asFunc(&azblobReq) {
azEtag := azcore.ETag(etag)
azblobReq.AccessConditions = &azblob.AccessConditions{
ModifiedAccessConditions: &container.ModifiedAccessConditions{
IfMatch: &azEtag,
},
}
} else if asFunc(&gcsHandle) {
*gcsHandle = (*gcsHandle).If(storage.Conditions{
GenerationMatch: etagToGeneration(etag),
})
}
}

func getProviderErrorStatusCode(err error) int {
var awsV1Err awserr.RequestFailure
var azureErr *azcore.ResponseError
var gcpErr *googleapi.Error

if errors.As(err, &awsV1Err); awsV1Err != nil {
return awsV1Err.StatusCode()
} else if errors.As(err, &azureErr); azureErr != nil {
return azureErr.StatusCode
} else if errors.As(err, &gcpErr); gcpErr != nil {
return gcpErr.Code
}
return 404
}

func getProviderEtag(reader *blob.Reader) string {
var awsV1Resp s3.GetObjectOutput
var azureResp azblob.DownloadStreamResponse
var gcpResp *storage.Reader

if reader.As(&awsV1Resp) {
return *awsV1Resp.ETag
} else if reader.As(&azureResp) {
return string(*azureResp.ETag)
} else if reader.As(&gcpResp) {
return generationToEtag(gcpResp.Attrs.Generation)
}

return ""
}

func (ba BucketAdapter) NewRangeReaderEtag(ctx context.Context, key string, offset, length int64, etag string) (io.ReadCloser, string, int, error) {
reader, err := ba.Bucket.NewRangeReader(ctx, key, offset, length, &blob.ReaderOptions{
BeforeRead: func(asFunc func(interface{}) bool) error {
var req *s3.GetObjectInput
if len(etag) > 0 && asFunc(&req) {
req.IfMatch = &etag
if len(etag) > 0 {
setProviderEtag(asFunc, etag)
}
return nil
},
})
status := 206
if err != nil {
var resp awserr.RequestFailure
errors.As(err, &resp)
status = 404
if resp != nil {
status = resp.StatusCode()
if isRefreshRequiredCode(resp.StatusCode()) {
return nil, "", resp.StatusCode(), &RefreshRequiredError{resp.StatusCode()}
}
status = getProviderErrorStatusCode(err)
if isRefreshRequiredCode(status) {
return nil, "", status, &RefreshRequiredError{status}
}

return nil, "", status, err
}
resultETag := ""
var resp s3.GetObjectOutput
if reader.As(&resp) {
resultETag = *resp.ETag
}
return reader, resultETag, status, nil

return reader, getProviderEtag(reader), status, nil
}

func (ba BucketAdapter) Close() error {
Expand Down
58 changes: 58 additions & 0 deletions pmtiles/bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,22 @@ package pmtiles

import (
"context"
"errors"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/storage/azblob"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/stretchr/testify/assert"
_ "gocloud.dev/blob/fileblob"
"google.golang.org/api/googleapi"
)

func TestNormalizeLocalFile(t *testing.T) {
Expand Down Expand Up @@ -206,3 +213,54 @@ func TestFileShorterThan16K(t *testing.T) {
assert.Nil(t, err)
assert.Equal(t, 3, len(data))
}

func TestSetProviderEtagAws(t *testing.T) {
var awsV1Req s3.GetObjectInput
assert.Nil(t, awsV1Req.IfMatch)
asFunc := func(i interface{}) bool {
v, ok := i.(**s3.GetObjectInput)
if ok {
*v = &awsV1Req
}
return true
}
setProviderEtag(asFunc, "123")
assert.Equal(t, aws.String("123"), awsV1Req.IfMatch)
}

func TestSetProviderEtagAzure(t *testing.T) {
var azOptions azblob.DownloadStreamOptions
assert.Nil(t, azOptions.AccessConditions)
asFunc := func(i interface{}) bool {
v, ok := i.(**azblob.DownloadStreamOptions)
if ok {
*v = &azOptions
}
return ok
}
setProviderEtag(asFunc, "123")
assert.Equal(t, azcore.ETag("123"), *azOptions.AccessConditions.ModifiedAccessConditions.IfMatch)
}

func TestGetProviderErrorStatusCode(t *testing.T) {
awsErr := awserr.NewRequestFailure(awserr.New("", "", nil), 500, "")
statusCode := getProviderErrorStatusCode(awsErr)
assert.Equal(t, 500, statusCode)

azureErr := &azcore.ResponseError{StatusCode: 500}
statusCode = getProviderErrorStatusCode(azureErr)
assert.Equal(t, 500, statusCode)

gcpErr := &googleapi.Error{Code: 500}
statusCode = getProviderErrorStatusCode(gcpErr)
assert.Equal(t, 500, statusCode)

err := errors.New("generic error")
statusCode = getProviderErrorStatusCode(err)
assert.Equal(t, 404, statusCode)
}

func TestGenerationEtag(t *testing.T) {
assert.Equal(t, int64(123), etagToGeneration("123"))
assert.Equal(t, "123", generationToEtag(int64(123)))
}

0 comments on commit ddd3524

Please sign in to comment.