Skip to content

Commit

Permalink
Handle server value from put requests
Browse files Browse the repository at this point in the history
  • Loading branch information
nopcoder committed Sep 6, 2023
1 parent db2f6c4 commit 1bf441a
Showing 1 changed file with 36 additions and 34 deletions.
70 changes: 36 additions & 34 deletions pkg/block/s3/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
Expand All @@ -20,6 +20,8 @@ import (
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
"github.com/treeverse/lakefs/pkg/block"
"github.com/treeverse/lakefs/pkg/block/params"
"github.com/treeverse/lakefs/pkg/logging"
Expand All @@ -33,8 +35,7 @@ var (

type Adapter struct {
clients *ClientCache
respServer string
respServerLock sync.Mutex
respServer atomic.Pointer[string]
ServerSideEncryption string
ServerSideEncryptionKmsKeyID string
preSignedExpiry time.Duration
Expand Down Expand Up @@ -225,9 +226,10 @@ func (a *Adapter) Put(ctx context.Context, obj block.ObjectPointer, sizeBytes in
}

client := a.clients.Get(ctx, bucket)
resp, err := client.PutObject(ctx, &putObject, s3.WithAPIOptions(
v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware,
))
resp, err := client.PutObject(ctx, &putObject,
s3.WithAPIOptions(v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware),
a.registerCaptureServerMiddleware(),
)
if err != nil {
return err
}
Expand All @@ -238,6 +240,20 @@ func (a *Adapter) Put(ctx context.Context, obj block.ObjectPointer, sizeBytes in
return nil
}

// captureServerDeserializeMiddleware extracts the server name from the response and sets it on the block adapter
func (a *Adapter) captureServerDeserializeMiddleware(ctx context.Context, input middleware.DeserializeInput, handler middleware.DeserializeHandler) (middleware.DeserializeOutput, middleware.Metadata, error) {
output, m, err := handler.HandleDeserialize(ctx, input)
if err == nil {
if rawResponse, ok := output.RawResponse.(*smithyhttp.Response); ok {
s := rawResponse.Header.Get("Server")
if s != "" {
a.respServer.Store(&s)
}
}
}
return output, m, err
}

func (a *Adapter) UploadPart(ctx context.Context, obj block.ObjectPointer, sizeBytes int64, reader io.Reader, uploadID string, partNumber int) (*block.UploadPartResponse, error) {
var err error
defer reportMetrics("UploadPart", time.Now(), &sizeBytes, &err)
Expand All @@ -262,9 +278,10 @@ func (a *Adapter) UploadPart(ctx context.Context, obj block.ObjectPointer, sizeB
}

client := a.clients.Get(ctx, bucket)
resp, err := client.UploadPart(ctx, uploadPartInput, s3.WithAPIOptions(
v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware,
))
resp, err := client.UploadPart(ctx, uploadPartInput,
s3.WithAPIOptions(v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware),
a.registerCaptureServerMiddleware(),
)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -776,37 +793,15 @@ func (a *Adapter) ResolveNamespace(storageNamespace, key string, identifierType
}

func (a *Adapter) RuntimeStats() map[string]string {
a.respServerLock.Lock()
defer a.respServerLock.Unlock()
if a.respServer == "" {
respServer := aws.ToString(a.respServer.Load())
if respServer == "" {
return nil
}
return map[string]string{
"resp_server": a.respServer,
"resp_server": respServer,
}
}

// extractS3Server extracts the responding server from the response.
// TODO(barak): check how to extract server name from response using sdk v2
/*
func (a *Adapter) extractS3Server(resp *http.Response) {
if resp == nil || resp.Header == nil {
return
}
// Extract the responding server from the response.
// Expected values: "S3" from AWS, "MinIO" for MinIO. Others unknown.
server := resp.Header.Get("Server")
if server == "" {
return
}
a.respServerLock.Lock()
defer a.respServerLock.Unlock()
a.respServer = server
}
*/

func (a *Adapter) managerUpload(ctx context.Context, obj block.ObjectPointer, reader io.Reader, opts block.PutOpts) error {
bucket, key, _, err := a.extractParamsFromObj(obj)
if err != nil {
Expand Down Expand Up @@ -849,6 +844,13 @@ func (a *Adapter) extractParamsFromObj(obj block.ObjectPointer) (string, string,
return bucket, key, qk, nil
}

func (a *Adapter) registerCaptureServerMiddleware() func(*s3.Options) {
fn := middleware.DeserializeMiddlewareFunc("ResponseServerValue", a.captureServerDeserializeMiddleware)
return s3.WithAPIOptions(func(stack *middleware.Stack) error {
return stack.Deserialize.Add(fn, middleware.After)
})
}

func ExtractParamsFromQK(qk block.QualifiedKey) (string, string) {
bucket, prefix, _ := strings.Cut(qk.GetStorageNamespace(), "/")
key := qk.GetKey()
Expand Down

0 comments on commit 1bf441a

Please sign in to comment.