diff --git a/br/pkg/lightning/backend/external/BUILD.bazel b/br/pkg/lightning/backend/external/BUILD.bazel
index 89950d8348413..f2f3b0eb6317c 100644
--- a/br/pkg/lightning/backend/external/BUILD.bazel
+++ b/br/pkg/lightning/backend/external/BUILD.bazel
@@ -33,6 +33,7 @@ go_library(
         "//br/pkg/storage",
         "//pkg/kv",
         "//pkg/metrics",
+        "//pkg/util",
         "//pkg/util/hack",
         "//pkg/util/logutil",
         "//pkg/util/size",
diff --git a/br/pkg/lightning/backend/external/byte_reader.go b/br/pkg/lightning/backend/external/byte_reader.go
index 3ba4a978390c4..a982c4349513d 100644
--- a/br/pkg/lightning/backend/external/byte_reader.go
+++ b/br/pkg/lightning/backend/external/byte_reader.go
@@ -209,6 +209,9 @@ func (r *byteReader) readNBytes(n int) ([]byte, error) {
         return bs[0], nil
     }
     // need to flatten bs
+    if n <= 0 {
+        return nil, errors.Errorf("illegal n (%d) when reading from external storage", n)
+    }
     if n > int(size.GB) {
         return nil, errors.Errorf("read %d bytes from external storage, exceed max limit %d", n, size.GB)
     }
diff --git a/br/pkg/lightning/backend/external/reader.go b/br/pkg/lightning/backend/external/reader.go
index 34370d005fca9..d5f3d05ff55c6 100644
--- a/br/pkg/lightning/backend/external/reader.go
+++ b/br/pkg/lightning/backend/external/reader.go
@@ -20,13 +20,14 @@ import (
     "io"
     "time"
 
+    "github.com/pingcap/errors"
     "github.com/pingcap/tidb/br/pkg/lightning/log"
     "github.com/pingcap/tidb/br/pkg/membuf"
     "github.com/pingcap/tidb/br/pkg/storage"
     "github.com/pingcap/tidb/pkg/metrics"
+    "github.com/pingcap/tidb/pkg/util"
     "github.com/pingcap/tidb/pkg/util/logutil"
     "go.uber.org/zap"
-    "golang.org/x/sync/errgroup"
 )
 
 func readAllData(
@@ -65,12 +66,13 @@ func readAllData(
     if err != nil {
         return err
     }
-    var eg errgroup.Group
+    eg, egCtx := util.NewErrorGroupWithRecoverWithCtx(ctx)
+    // TODO(lance6716): limiting the concurrency of eg to 30 does not help
     for i := range dataFiles {
         i := i
         eg.Go(func() error {
-            return readOneFile(
-                ctx,
+            err2 := readOneFile(
+                egCtx,
                 storage,
                 dataFiles[i],
                 startKey,
@@ -80,6 +82,7 @@ func readAllData(
                 bufPool,
                 output,
             )
+            return errors.Annotatef(err2, "failed to read file %s", dataFiles[i])
         })
     }
     return eg.Wait()
diff --git a/br/pkg/lightning/backend/external/writer.go b/br/pkg/lightning/backend/external/writer.go
index 25b15f23699a5..514937393978b 100644
--- a/br/pkg/lightning/backend/external/writer.go
+++ b/br/pkg/lightning/backend/external/writer.go
@@ -197,7 +197,6 @@ func (b *WriterBuilder) Build(
         filenamePrefix: filenamePrefix,
         keyAdapter:     keyAdapter,
         writerID:       writerID,
-        kvStore:        nil,
         onClose:        b.onClose,
         closed:         false,
         multiFileStats: make([]MultipleFilesStat, 1),
@@ -293,8 +292,7 @@ type Writer struct {
     filenamePrefix string
     keyAdapter     common.KeyAdapter
 
-    kvStore *KeyValueStore
-    rc      *rangePropertiesCollector
+    rc *rangePropertiesCollector
 
     memSizeLimit uint64
 
@@ -400,88 +398,53 @@ func (w *Writer) recordMinMax(newMin, newMax tidbkv.Key, size uint64) {
     w.totalSize += size
 }
 
+const flushKVsRetryTimes = 3
+
 func (w *Writer) flushKVs(ctx context.Context, fromClose bool) (err error) {
     if len(w.kvLocations) == 0 {
         return nil
     }
-    logger := logutil.Logger(ctx)
-    dataFile, statFile, dataWriter, statWriter, err := w.createStorageWriter(ctx)
-    if err != nil {
-        return err
-    }
-
-    var (
-        savedBytes                  uint64
-        statSize                    int
-        sortDuration, writeDuration time.Duration
-        writeStartTime              time.Time
+    logger := logutil.Logger(ctx).With(
+        zap.String("writer-id", w.writerID),
+        zap.Int("sequence-number", w.currentSeq),
     )
-    savedBytes = w.batchSize
-    startTs := time.Now()
-
-    kvCnt := len(w.kvLocations)
-    defer func() {
-        w.currentSeq++
-        err1, err2 := dataWriter.Close(ctx), statWriter.Close(ctx)
-        if err != nil {
-            return
-        }
-        if err1 != nil {
-            logger.Error("close data writer failed", zap.Error(err1))
-            err = err1
-            return
-        }
-        if err2 != nil {
-            logger.Error("close stat writer failed", zap.Error(err2))
-            err = err2
-            return
-        }
-        writeDuration = time.Since(writeStartTime)
-        logger.Info("flush kv",
-            zap.Uint64("bytes", savedBytes),
-            zap.Int("kv-cnt", kvCnt),
-            zap.Int("stat-size", statSize),
-            zap.Duration("sort-time", sortDuration),
-            zap.Duration("write-time", writeDuration),
-            zap.String("sort-speed(kv/s)", getSpeed(uint64(kvCnt), sortDuration.Seconds(), false)),
-            zap.String("write-speed(bytes/s)", getSpeed(savedBytes, writeDuration.Seconds(), true)),
-            zap.String("writer-id", w.writerID),
-        )
-        metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("write").Observe(writeDuration.Seconds())
-        metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("write").Observe(float64(savedBytes) / 1024.0 / 1024.0 / writeDuration.Seconds())
-        metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("sort_and_write").Observe(time.Since(startTs).Seconds())
-        metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort_and_write").Observe(float64(savedBytes) / 1024.0 / 1024.0 / time.Since(startTs).Seconds())
-    }()
 
     sortStart := time.Now()
     slices.SortFunc(w.kvLocations, func(i, j membuf.SliceLocation) int {
         return bytes.Compare(w.getKeyByLoc(i), w.getKeyByLoc(j))
     })
-    sortDuration = time.Since(sortStart)
-
-    writeStartTime = time.Now()
+    sortDuration := time.Since(sortStart)
     metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("sort").Observe(sortDuration.Seconds())
-    metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort").Observe(float64(savedBytes) / 1024.0 / 1024.0 / sortDuration.Seconds())
-    w.kvStore, err = NewKeyValueStore(ctx, dataWriter, w.rc)
-    if err != nil {
-        return err
-    }
-
-    for _, pair := range w.kvLocations {
-        err = w.kvStore.addEncodedData(w.kvBuffer.GetSlice(pair))
-        if err != nil {
-            return err
+    metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort").Observe(float64(w.batchSize) / 1024.0 / 1024.0 / sortDuration.Seconds())
+
+    writeStartTime := time.Now()
+    var dataFile, statFile string
+    for i := 0; i < flushKVsRetryTimes; i++ {
+        dataFile, statFile, err = w.flushSortedKVs(ctx)
+        if err == nil {
+            break
         }
+        logger.Warn("flush sorted kv failed",
+            zap.Error(err),
+            zap.Int("retry-count", i),
+        )
     }
-
-    w.kvStore.Close()
-    encodedStat := w.rc.encode()
-    statSize = len(encodedStat)
-    _, err = statWriter.Write(ctx, encodedStat)
     if err != nil {
         return err
     }
+    writeDuration := time.Since(writeStartTime)
+    kvCnt := len(w.kvLocations)
+    logger.Info("flush kv",
+        zap.Uint64("bytes", w.batchSize),
+        zap.Int("kv-cnt", kvCnt),
+        zap.Duration("sort-time", sortDuration),
+        zap.Duration("write-time", writeDuration),
+        zap.String("sort-speed(kv/s)", getSpeed(uint64(kvCnt), sortDuration.Seconds(), false)),
+        zap.String("writer-id", w.writerID),
+    )
+    totalDuration := time.Since(sortStart)
+    metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("sort_and_write").Observe(totalDuration.Seconds())
+    metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("sort_and_write").Observe(float64(w.batchSize) / 1024.0 / 1024.0 / totalDuration.Seconds())
 
     minKey, maxKey := w.getKeyByLoc(w.kvLocations[0]), w.getKeyByLoc(w.kvLocations[len(w.kvLocations)-1])
     w.recordMinMax(minKey, maxKey, uint64(w.kvSize))
@@ -507,9 +470,73 @@ func (w *Writer) flushKVs(ctx context.Context, fromClose bool) (err error) {
     w.kvBuffer.Reset()
     w.rc.reset()
     w.batchSize = 0
+    w.currentSeq++
     return nil
 }
 
+func (w *Writer) flushSortedKVs(ctx context.Context) (string, string, error) {
+    logger := logutil.Logger(ctx).With(
+        zap.String("writer-id", w.writerID),
+        zap.Int("sequence-number", w.currentSeq),
+    )
+    writeStartTime := time.Now()
+    dataFile, statFile, dataWriter, statWriter, err := w.createStorageWriter(ctx)
+    if err != nil {
+        return "", "", err
+    }
+    defer func() {
+        // close the writers when an error occurs. If no error happens, the
+        // writers are closed explicitly below and set to nil.
+        if dataWriter != nil {
+            _ = dataWriter.Close(ctx)
+        }
+        if statWriter != nil {
+            _ = statWriter.Close(ctx)
+        }
+    }()
+    kvStore, err := NewKeyValueStore(ctx, dataWriter, w.rc)
+    if err != nil {
+        return "", "", err
+    }
+
+    for _, pair := range w.kvLocations {
+        err = kvStore.addEncodedData(w.kvBuffer.GetSlice(pair))
+        if err != nil {
+            return "", "", err
+        }
+    }
+
+    kvStore.Close()
+    encodedStat := w.rc.encode()
+    statSize := len(encodedStat)
+    _, err = statWriter.Write(ctx, encodedStat)
+    if err != nil {
+        return "", "", err
+    }
+    err = dataWriter.Close(ctx)
+    dataWriter = nil
+    if err != nil {
+        return "", "", err
+    }
+    err = statWriter.Close(ctx)
+    statWriter = nil
+    if err != nil {
+        return "", "", err
+    }
+
+    writeDuration := time.Since(writeStartTime)
+    logger.Info("flush sorted kv",
+        zap.Uint64("bytes", w.batchSize),
+        zap.Int("stat-size", statSize),
+        zap.Duration("write-time", writeDuration),
+        zap.String("write-speed(bytes/s)", getSpeed(w.batchSize, writeDuration.Seconds(), true)),
+    )
+    metrics.GlobalSortWriteToCloudStorageDuration.WithLabelValues("write").Observe(writeDuration.Seconds())
+    metrics.GlobalSortWriteToCloudStorageRate.WithLabelValues("write").Observe(float64(w.batchSize) / 1024.0 / 1024.0 / writeDuration.Seconds())
+
+    return dataFile, statFile, nil
+}
+
 func (w *Writer) getKeyByLoc(loc membuf.SliceLocation) []byte {
     block := w.kvBuffer.GetSlice(loc)
     keyLen := binary.BigEndian.Uint64(block[:lengthBytes])
diff --git a/br/pkg/storage/BUILD.bazel b/br/pkg/storage/BUILD.bazel
index 3c5dd4d662705..42e96b6126158 100644
--- a/br/pkg/storage/BUILD.bazel
+++ b/br/pkg/storage/BUILD.bazel
@@ -7,6 +7,7 @@ go_library(
         "compress.go",
         "flags.go",
         "gcs.go",
+        "gcs_extra.go",
         "hdfs.go",
         "helper.go",
         "ks3.go",
@@ -49,6 +50,7 @@ go_library(
         "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//bloberror",
         "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//blockblob",
         "@com_github_azure_azure_sdk_for_go_sdk_storage_azblob//container",
+        "@com_github_go_resty_resty_v2//:resty",
        "@com_github_google_uuid//:uuid",
        "@com_github_klauspost_compress//gzip",
        "@com_github_klauspost_compress//snappy",
diff --git a/br/pkg/storage/gcs.go b/br/pkg/storage/gcs.go
index 915537eec9166..5afe739260b57 100644
--- a/br/pkg/storage/gcs.go
+++ b/br/pkg/storage/gcs.go
@@ -99,6 +99,7 @@ func (options *GCSBackendOptions) parseFromFlags(flags *pflag.FlagSet) error {
 type GCSStorage struct {
     gcs    *backuppb.GCS
     bucket *storage.BucketHandle
+    cli    *storage.Client
 }
 
 // GetBucketHandle gets the handle to the GCS API on the bucket.
@@ -272,12 +273,29 @@ func (s *GCSStorage) URI() string {
 }
 
 // Create implements ExternalStorage interface.
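+// When the caller passes a WriterOption with Concurrency > 1 against a
+// real GCS endpoint, Create uses the XML multipart upload path implemented
+// by GCSWriter in gcs_extra.go; otherwise it keeps the original
+// single-stream writer.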
-func (s *GCSStorage) Create(ctx context.Context, name string, _ *WriterOption) (ExternalFileWriter, error) {
-    object := s.objectName(name)
-    wc := s.bucket.Object(object).NewWriter(ctx)
-    wc.StorageClass = s.gcs.StorageClass
-    wc.PredefinedACL = s.gcs.PredefinedAcl
-    return newFlushStorageWriter(wc, &emptyFlusher{}, wc), nil
+func (s *GCSStorage) Create(ctx context.Context, name string, wo *WriterOption) (ExternalFileWriter, error) {
+    // NewGCSWriter requires a real testing environment on Google Cloud.
+    mockGCS := intest.InTest && strings.Contains(s.gcs.GetEndpoint(), "127.0.0.1")
+    if wo == nil || wo.Concurrency <= 1 || mockGCS {
+        object := s.objectName(name)
+        wc := s.bucket.Object(object).NewWriter(ctx)
+        wc.StorageClass = s.gcs.StorageClass
+        wc.PredefinedACL = s.gcs.PredefinedAcl
+        return newFlushStorageWriter(wc, &emptyFlusher{}, wc), nil
+    }
+    uri := s.objectName(name)
+    // 5 MB is the minimum part size for GCS.
+    partSize := int64(gcsMinimumChunkSize)
+    if wo.PartSize > partSize {
+        partSize = wo.PartSize
+    }
+    w, err := NewGCSWriter(ctx, s.cli, uri, partSize, wo.Concurrency, s.gcs.Bucket)
+    if err != nil {
+        return nil, errors.Trace(err)
+    }
+    fw := newFlushStorageWriter(w, &emptyFlusher{}, w)
+    bw := newBufferedWriter(fw, int(partSize), NoCompression)
+    return bw, nil
 }
 
 // Rename file name from oldFileName to newFileName.
@@ -371,7 +389,7 @@ skipHandleCred:
         // so we need find sst in slash directory
         gcs.Prefix += "//"
     }
-    return &GCSStorage{gcs: gcs, bucket: bucket}, nil
+    return &GCSStorage{gcs: gcs, bucket: bucket, cli: client}, nil
 }
 
 func hasSSTFiles(ctx context.Context, bucket *storage.BucketHandle, prefix string) bool {
diff --git a/br/pkg/storage/gcs_extra.go b/br/pkg/storage/gcs_extra.go
new file mode 100644
index 0000000000000..15b3ef5715d5a
--- /dev/null
+++ b/br/pkg/storage/gcs_extra.go
@@ -0,0 +1,419 @@
+// Copyright 2023 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Learned from https://github.com/liqiuqing/gcsmpu
+
+package storage
+
+import (
+    "bytes"
+    "context"
+    "encoding/xml"
+    "fmt"
+    "net"
+    "net/http"
+    "net/url"
+    "runtime"
+    "slices"
+    "strconv"
+    "sync"
+    "time"
+
+    "cloud.google.com/go/storage"
+    "github.com/go-resty/resty/v2"
+    "go.uber.org/atomic"
+)
+
+// GCSWriter uses the XML multipart upload API to upload a single file.
+// https://cloud.google.com/storage/docs/multipart-uploads.
+// GCSWriter will attempt to cancel uploads that fail due to an exception.
+// If the upload fails in a way that precludes cancellation, such as a
+// hardware failure, process termination, or power outage, then the incomplete
+// upload may persist indefinitely. To mitigate this, set the
+// `AbortIncompleteMultipartUpload` with a nonzero `Age` in bucket lifecycle
+// rules, or refer to the XML API documentation linked above to learn more
+// about how to list and delete individual uploads.
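+//
+// The flow is: NewGCSWriter initiates the upload with a signed POST
+// request, Write hands each part to a pool of workers that PUT it via a
+// signed URL, and Close waits for the workers and then completes (or
+// cancels) the upload.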
+type GCSWriter struct {
+    uploadBase
+    mutex       sync.Mutex
+    xmlMPUParts []*xmlMPUPart
+    wg          sync.WaitGroup
+    err         atomic.Error
+    chunkSize   int64
+    workers     int
+    totalSize   int64
+    uploadID    string
+    chunkCh     chan chunk
+    curPart     int
+}
+
+// NewGCSWriter returns a GCSWriter which uses the GCS multipart upload API
+// behind the scenes.
+func NewGCSWriter(
+    ctx context.Context,
+    cli *storage.Client,
+    uri string,
+    partSize int64,
+    parallelCnt int,
+    bucketName string,
+) (*GCSWriter, error) {
+    if partSize < gcsMinimumChunkSize || partSize > gcsMaximumChunkSize {
+        return nil, fmt.Errorf(
+            "invalid chunk size: %d. Chunk size must be between %d and %d",
+            partSize, gcsMinimumChunkSize, gcsMaximumChunkSize,
+        )
+    }
+
+    w := &GCSWriter{
+        uploadBase: uploadBase{
+            ctx:             ctx,
+            cli:             cli,
+            bucket:          bucketName,
+            blob:            uri,
+            retry:           defaultRetry,
+            signedURLExpiry: defaultSignedURLExpiry,
+        },
+        chunkSize: partSize,
+        workers:   parallelCnt,
+    }
+    if err := w.init(); err != nil {
+        return nil, fmt.Errorf("failed to initiate GCSWriter: %w", err)
+    }
+
+    return w, nil
+}
+
+func (w *GCSWriter) init() error {
+    opts := &storage.SignedURLOptions{
+        Scheme:          storage.SigningSchemeV4,
+        Method:          "POST",
+        Expires:         time.Now().Add(w.signedURLExpiry),
+        QueryParameters: url.Values{mpuInitiateQuery: []string{""}},
+    }
+    u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts)
+    if err != nil {
+        return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err)
+    }
+
+    client := resty.New()
+    resp, err := client.R().Post(u)
+    if err != nil {
+        return fmt.Errorf("POST request failed: %s", err)
+    }
+
+    if resp.StatusCode() != http.StatusOK {
+        return fmt.Errorf("POST request returned non-OK status: %d", resp.StatusCode())
+    }
+    body := resp.Body()
+
+    result := InitiateMultipartUploadResult{}
+    err = xml.Unmarshal(body, &result)
+    if err != nil {
+        return fmt.Errorf("failed to unmarshal response body: %s", err)
+    }
+
+    uploadID := result.UploadId
+    w.uploadID = uploadID
+    w.chunkCh = make(chan chunk)
+    for i := 0; i < w.workers; i++ {
+        w.wg.Add(1)
+        go w.readChunk(w.chunkCh)
+    }
+    w.curPart = 1
+    return nil
+}
+
+func (w *GCSWriter) readChunk(ch chan chunk) {
+    defer w.wg.Done()
+    for {
+        data, ok := <-ch
+        if !ok {
+            break
+        }
+
+        select {
+        case <-w.ctx.Done():
+            data.cleanup()
+            return
+        default:
+            part := &xmlMPUPart{
+                uploadBase: w.uploadBase,
+                uploadID:   w.uploadID,
+                buf:        data.buf,
+                partNumber: data.num,
+            }
+            if w.err.Load() == nil {
+                if err := part.Upload(); err != nil {
+                    w.err.Store(err)
+                }
+            }
+            part.buf = nil
+            w.appendMPUPart(part)
+            data.cleanup()
+        }
+    }
+}
+
+// Write uploads the given bytes as a part to Google Cloud Storage. Write is
+// not safe for concurrent use.
+func (w *GCSWriter) Write(p []byte) (n int, err error) {
+    if w.curPart > gcsMaximumParts {
+        err = fmt.Errorf("exceed maximum parts %d", gcsMaximumParts)
+        if w.err.Load() == nil {
+            w.err.Store(err)
+        }
+        return 0, err
+    }
+    buf := make([]byte, len(p))
+    copy(buf, p)
+    w.chunkCh <- chunk{
+        buf:     buf,
+        num:     w.curPart,
+        cleanup: func() {},
+    }
+    w.curPart++
+    return len(p), nil
+}
+
+// Close finishes the upload.
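+// It waits for all in-flight part uploads to finish, then completes the
+// multipart upload; if finalization fails it attempts to cancel the upload.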
+func (w *GCSWriter) Close() error {
+    close(w.chunkCh)
+    w.wg.Wait()
+
+    if err := w.err.Load(); err != nil {
+        return err
+    }
+
+    err := w.finalizeXMLMPU()
+    if err == nil {
+        return nil
+    }
+    errC := w.cancel()
+    if errC != nil {
+        return fmt.Errorf("failed to finalize multipart upload: %s, and failed to cancel multipart upload: %s", err, errC)
+    }
+    return fmt.Errorf("failed to finalize multipart upload: %s", err)
+}
+
+const (
+    mpuInitiateQuery   = "uploads"
+    mpuPartNumberQuery = "partNumber"
+    mpuUploadIDQuery   = "uploadId"
+)
+
+type uploadBase struct {
+    cli             *storage.Client
+    ctx             context.Context
+    bucket          string
+    blob            string
+    retry           int
+    signedURLExpiry time.Duration
+}
+
+const (
+    defaultRetry           = 3
+    defaultSignedURLExpiry = 6 * time.Hour
+
+    gcsMinimumChunkSize = 5 * 1024 * 1024        // 5 MB
+    gcsMaximumChunkSize = 5 * 1024 * 1024 * 1024 // 5 GB
+    gcsMaximumParts     = 10000
+)
+
+// InitiateMultipartUploadResult is the XML response of the initiate
+// multipart upload request.
+type InitiateMultipartUploadResult struct {
+    XMLName  xml.Name `xml:"InitiateMultipartUploadResult"`
+    Text     string   `xml:",chardata"`
+    Xmlns    string   `xml:"xmlns,attr"`
+    Bucket   string   `xml:"Bucket"`
+    Key      string   `xml:"Key"`
+    UploadId string   `xml:"UploadId"`
+}
+
+// Part records one uploaded part for the complete-upload request.
+type Part struct {
+    Text       string `xml:",chardata"`
+    PartNumber int    `xml:"PartNumber"`
+    ETag       string `xml:"ETag"`
+}
+
+// CompleteMultipartUpload is the XML body sent to finish the upload.
+type CompleteMultipartUpload struct {
+    XMLName xml.Name `xml:"CompleteMultipartUpload"`
+    Text    string   `xml:",chardata"`
+    Parts   []Part   `xml:"Part"`
+}
+
+func (w *GCSWriter) finalizeXMLMPU() error {
+    finalXMLRoot := CompleteMultipartUpload{
+        Parts: make([]Part, 0, len(w.xmlMPUParts)),
+    }
+    slices.SortFunc(w.xmlMPUParts, func(a, b *xmlMPUPart) int {
+        return a.partNumber - b.partNumber
+    })
+    for _, part := range w.xmlMPUParts {
+        part := Part{
+            PartNumber: part.partNumber,
+            ETag:       part.etag,
+        }
+        finalXMLRoot.Parts = append(finalXMLRoot.Parts, part)
+    }
+
+    xmlBytes, err := xml.Marshal(finalXMLRoot)
+    if err != nil {
+        return fmt.Errorf("failed to encode XML: %v", err)
+    }
+
+    opts := &storage.SignedURLOptions{
+        Scheme:          storage.SigningSchemeV4,
+        Method:          "POST",
+        Expires:         time.Now().Add(w.signedURLExpiry),
+        QueryParameters: url.Values{mpuUploadIDQuery: []string{w.uploadID}},
+    }
+    u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts)
+    if err != nil {
+        return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err)
+    }
+
+    client := resty.New()
+    resp, err := client.R().SetBody(xmlBytes).Post(u)
+    if err != nil {
+        return fmt.Errorf("POST request failed: %s", err)
+    }
+
+    if resp.StatusCode() != http.StatusOK {
+        return fmt.Errorf("POST request returned non-OK status: %d, body: %s", resp.StatusCode(), resp.String())
+    }
+    return nil
+}
+
+type chunk struct {
+    buf     []byte
+    num     int
+    cleanup func()
+}
+
+func (w *GCSWriter) appendMPUPart(part *xmlMPUPart) {
+    w.mutex.Lock()
+    defer w.mutex.Unlock()
+
+    w.xmlMPUParts = append(w.xmlMPUParts, part)
+}
+
+func (w *GCSWriter) cancel() error {
+    opts := &storage.SignedURLOptions{
+        Scheme:          storage.SigningSchemeV4,
+        Method:          "DELETE",
+        Expires:         time.Now().Add(w.signedURLExpiry),
+        QueryParameters: url.Values{mpuUploadIDQuery: []string{w.uploadID}},
+    }
+    u, err := w.cli.Bucket(w.bucket).SignedURL(w.blob, opts)
+    if err != nil {
+        return fmt.Errorf("Bucket(%q).SignedURL: %s", w.bucket, err)
+    }
+
+    client := resty.New()
+    resp, err := client.R().Delete(u)
+    if err != nil {
+        return fmt.Errorf("DELETE request failed: %s", err)
+    }
+
+    if resp.StatusCode() != http.StatusNoContent {
+        return fmt.Errorf("DELETE request returned non-204 status: %d", resp.StatusCode())
+    }
+
+    return nil
+}
+
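+// xmlMPUPart holds a single part of an XML multipart upload: the buffered
+// bytes, the part number, and the ETag returned by GCS once the part is
+// uploaded.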
+type xmlMPUPart struct {
+    uploadBase
+    buf        []byte
+    uploadID   string
+    partNumber int
+    etag       string
+}
+
+func (p *xmlMPUPart) Clone() *xmlMPUPart {
+    return &xmlMPUPart{
+        uploadBase: p.uploadBase,
+        uploadID:   p.uploadID,
+        buf:        p.buf,
+        partNumber: p.partNumber,
+    }
+}
+
+func (p *xmlMPUPart) Upload() error {
+    var err error
+    for i := 0; i < p.retry; i++ {
+        err = p.upload()
+        if err == nil {
+            return nil
+        }
+    }
+
+    return fmt.Errorf("failed to upload part %d: %w", p.partNumber, err)
+}
+
+func (p *xmlMPUPart) upload() error {
+    opts := &storage.SignedURLOptions{
+        Scheme:  storage.SigningSchemeV4,
+        Method:  "PUT",
+        Expires: time.Now().Add(p.signedURLExpiry),
+        QueryParameters: url.Values{
+            mpuUploadIDQuery:   []string{p.uploadID},
+            mpuPartNumberQuery: []string{strconv.Itoa(p.partNumber)},
+        },
+    }
+
+    u, err := p.cli.Bucket(p.bucket).SignedURL(p.blob, opts)
+    if err != nil {
+        return fmt.Errorf("Bucket(%q).SignedURL: %s", p.bucket, err)
+    }
+
+    req, err := http.NewRequest("PUT", u, bytes.NewReader(p.buf))
+    if err != nil {
+        return fmt.Errorf("failed to create PUT request: %s", err)
+    }
+    req = req.WithContext(p.ctx)
+
+    client := &http.Client{
+        Transport: createTransport(nil),
+    }
+    resp, err := client.Do(req)
+    if err != nil {
+        return fmt.Errorf("PUT request failed: %s", err)
+    }
+    defer resp.Body.Close()
+
+    p.etag = resp.Header.Get("ETag")
+
+    if resp.StatusCode != http.StatusOK {
+        return fmt.Errorf("PUT request returned non-OK status: %d", resp.StatusCode)
+    }
+    return nil
+}
+
+func createTransport(localAddr net.Addr) *http.Transport {
+    dialer := &net.Dialer{
+        Timeout:   30 * time.Second,
+        KeepAlive: 30 * time.Second,
+    }
+    if localAddr != nil {
+        dialer.LocalAddr = localAddr
+    }
+    return &http.Transport{
+        Proxy:                 http.ProxyFromEnvironment,
+        DialContext:           dialer.DialContext,
+        MaxIdleConns:          100,
+        IdleConnTimeout:       90 * time.Second,
+        TLSHandshakeTimeout:   10 * time.Second,
+        ExpectContinueTimeout: 1 * time.Second,
+        MaxIdleConnsPerHost:   runtime.GOMAXPROCS(0) + 1,
+    }
+}
diff --git a/br/pkg/storage/gcs_test.go b/br/pkg/storage/gcs_test.go
index f9346caffde80..6f9cbfa7ef687 100644
--- a/br/pkg/storage/gcs_test.go
+++ b/br/pkg/storage/gcs_test.go
@@ -3,7 +3,10 @@
 package storage
 
 import (
+    "bytes"
     "context"
+    "crypto/rand"
+    "flag"
     "fmt"
     "io"
     "os"
@@ -460,3 +463,39 @@ func TestReadRange(t *testing.T) {
     require.NoError(t, err)
     require.Equal(t, []byte("234"), content[:n])
 }
+
+var testingStorageURI = flag.String("testing-storage-uri", "", "the URI of the storage used for testing")
+
+func openTestingStorage(t *testing.T) ExternalStorage {
+    if *testingStorageURI == "" {
+        t.Skip("testingStorageURI is not set")
+    }
+    s, err := NewFromURL(context.Background(), *testingStorageURI)
+    require.NoError(t, err)
+    return s
+}
+
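+// TestMultiPartUpload needs a real GCS bucket; it can be run with something
+// like `go test ./br/pkg/storage -run TestMultiPartUpload -args
+// -testing-storage-uri="gcs://bucket/prefix"` (the exact URI parameters
+// depend on the environment).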
+func TestMultiPartUpload(t *testing.T) {
+    ctx := context.Background()
+
+    s := openTestingStorage(t)
+    if _, ok := s.(*GCSStorage); !ok {
+        t.Skipf("only test GCSStorage, got %T", s)
+    }
+
+    filename := "TestMultiPartUpload"
+    // fill the buffer with random content; the exact bytes don't matter
+    data := make([]byte, 100*1024*1024)
+    rand.Read(data)
+    w, err := s.Create(ctx, filename, &WriterOption{Concurrency: 10})
+    require.NoError(t, err)
+    _, err = w.Write(ctx, data)
+    require.NoError(t, err)
+    err = w.Close(ctx)
+    require.NoError(t, err)
+
+    got, err := s.ReadFile(ctx, filename)
+    require.NoError(t, err)
+    cmp := bytes.Compare(data, got)
+    require.Zero(t, cmp)
+}
diff --git a/br/pkg/storage/parse_test.go b/br/pkg/storage/parse_test.go
index 4e8884b557961..0669564961c77 100644
--- a/br/pkg/storage/parse_test.go
+++ b/br/pkg/storage/parse_test.go
@@ -147,6 +147,12 @@ func TestCreateStorage(t *testing.T) {
     require.Equal(t, "https://gcs.example.com/", gcs.Endpoint)
     require.Equal(t, "fakeCredentials", gcs.CredentialsBlob)
 
+    s, err = ParseBackend("gcs://bucket?endpoint=http://127.0.0.1/", gcsOpt)
+    require.NoError(t, err)
+    gcs = s.GetGcs()
+    require.NotNil(t, gcs)
+    require.Equal(t, "http://127.0.0.1/", gcs.Endpoint)
+
     err = os.WriteFile(fakeCredentialsFile, []byte("fakeCreds2"), credFilePerm)
     require.NoError(t, err)
     s, err = ParseBackend("gs://bucket4/backup/?credentials-file="+url.QueryEscape(fakeCredentialsFile), nil)
diff --git a/go.mod b/go.mod
index 02ab1ea21f3ae..fbdb378bf6563 100644
--- a/go.mod
+++ b/go.mod
@@ -41,6 +41,7 @@ require (
     github.com/fatih/color v1.15.0
     github.com/fsouza/fake-gcs-server v1.44.0
     github.com/go-ldap/ldap/v3 v3.4.4
+    github.com/go-resty/resty/v2 v2.7.0
     github.com/go-sql-driver/mysql v1.7.1
     github.com/gogo/protobuf v1.3.2
     github.com/golang/protobuf v1.5.3
diff --git a/go.sum b/go.sum
index aa001d8a15894..1ff7a08d4dc50 100644
--- a/go.sum
+++ b/go.sum
@@ -295,6 +295,8 @@ github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab/go.mod h1:/P9AE
 github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
 github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE=
 github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78=
+github.com/go-resty/resty/v2 v2.7.0 h1:me+K9p3uhSmXtrBZ4k9jcEAfJmuC8IivWHwaLZwPrFY=
+github.com/go-resty/resty/v2 v2.7.0/go.mod h1:9PWDzw47qPphMRFfhsyk0NnSgvluHcljSMVIq3w7q0I=
 github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
 github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
 github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
@@ -1095,6 +1097,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v
 golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
 golang.org/x/net v0.0.0-20210427231257-85d9c07bbe3a/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
 golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
+golang.org/x/net v0.0.0-20211029224645-99673261e6eb/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
 golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
 golang.org/x/net v0.0.0-20220517181318-183a9ca12b87/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk=
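For orientation, here is a minimal sketch of how the new multipart path is exercised end to end, assuming a reachable GCS bucket; the URI, object name, and sizes are placeholders, while NewFromURL, Create, and WriterOption come from the APIs touched in this diff:

package main

import (
    "context"

    "github.com/pingcap/tidb/br/pkg/storage"
)

func main() {
    ctx := context.Background()
    // Placeholder URI; NewFromURL builds a *GCSStorage for gcs:// schemes.
    store, err := storage.NewFromURL(ctx, "gcs://bucket/prefix")
    if err != nil {
        panic(err)
    }
    // Concurrency > 1 selects the XML multipart upload path added by this
    // diff; a PartSize below the 5 MB GCS minimum is raised by Create.
    w, err := store.Create(ctx, "example-object", &storage.WriterOption{
        Concurrency: 10,
        PartSize:    16 * 1024 * 1024,
    })
    if err != nil {
        panic(err)
    }
    if _, err := w.Write(ctx, []byte("some payload")); err != nil {
        panic(err)
    }
    // Close flushes the buffered writer and completes the multipart upload.
    if err := w.Close(ctx); err != nil {
        panic(err)
    }
}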