diff --git a/retrievalprovider/httpretrieval/multi_reader.go b/retrievalprovider/httpretrieval/multi_reader.go new file mode 100644 index 00000000..52f41c7d --- /dev/null +++ b/retrievalprovider/httpretrieval/multi_reader.go @@ -0,0 +1,142 @@ +package httpretrieval + +import ( + "errors" + "io" + + "github.com/filecoin-project/go-padreader" +) + +var errSeeker = errors.New("seeker can't seek") +var errWhence = errors.New("seek: invalid whence") + +type multiReader struct { + reader io.ReadSeeker + readerSize uint64 + readerOffset int + + nullReader io.Reader + nullReaderSize uint64 + nullReaderOffset int +} + +func newMultiReader(r io.ReadSeeker, size uint64) *multiReader { + padSize := padreader.PaddedSize(size) + nullReaderSize := uint64(padSize) - size + return &multiReader{ + reader: r, + readerSize: size, + + nullReader: io.LimitReader(nullReader{}, int64(nullReaderSize)), + nullReaderSize: nullReaderSize, + } +} + +func (mr *multiReader) Read(p []byte) (int, error) { + if int(mr.readerSize)-mr.readerOffset >= len(p) { + n, err := mr.reader.Read(p) + mr.readerOffset += n + return n, err + } + + var n int + var err error + remain := int(mr.readerSize) - mr.readerOffset + if remain > 0 { + n, err = mr.reader.Read(p[:remain]) + mr.readerOffset += n + if err != nil { + return n, err + } + } + + remain = int(mr.nullReaderSize) - mr.nullReaderOffset + if remain <= 0 { + return 0, io.EOF + } + if len(p)-n > remain { + n2, err := mr.nullReader.Read(p[n : remain+n]) + mr.nullReaderOffset += n2 + return n + n2, err + } + + n2, err := mr.nullReader.Read(p[n:]) + mr.nullReaderOffset += n2 + + return n + n2, err +} + +func (mr *multiReader) Seek(offset int64, whence int) (int64, error) { + switch whence { + default: + return 0, errWhence + case io.SeekStart: + seekOffset2 := 0 + seekOffset := offset + if offset > int64(mr.readerSize) { + seekOffset = int64(mr.readerSize) + seekOffset2 = int(offset - int64(mr.readerSize)) + } + _, err := mr.reader.Seek(seekOffset, whence) + if err != nil { + return 0, errSeeker + } + mr.readerOffset = int(seekOffset) + mr.nullReaderOffset = seekOffset2 + return offset, nil + case io.SeekCurrent: + if offset == 0 { + return 0, nil + } + + if offset > 0 { + if mr.readerSize > uint64(mr.readerOffset) { + remain := int64(mr.readerSize) - int64(mr.readerOffset) + seekOffset := offset + if offset > remain { + seekOffset = remain + mr.nullReaderOffset = int(offset) - int(remain) + } + _, err := mr.reader.Seek(seekOffset, whence) + if err != nil { + return 0, errSeeker + } + return offset, nil + } + mr.nullReaderOffset = +int(offset) + return offset, nil + } + if offset+int64(mr.nullReaderOffset) < 0 { + mr.nullReaderOffset = 0 + mr.readerOffset = int(offset) + mr.nullReaderOffset + int(mr.readerSize) + _, err := mr.reader.Seek(offset+int64(mr.nullReaderOffset), whence) + if err != nil { + return 0, errSeeker + } + return offset, nil + } + mr.nullReaderOffset += int(offset) + + return offset, nil + case io.SeekEnd: + _, err := mr.reader.Seek(0, whence) + if err != nil { + return 0, err + } + mr.readerOffset = int(mr.readerSize) + mr.nullReaderOffset = int(mr.nullReaderSize) + return int64(mr.readerSize) + int64(mr.nullReaderSize) + offset, nil + } +} + +var _ io.ReadSeeker = &multiReader{} + +type nullReader struct{} + +// Read writes NUL bytes into the provided byte slice. +func (nr nullReader) Read(b []byte) (int, error) { + for i := range b { + b[i] = 0 + } + return len(b), nil +} diff --git a/retrievalprovider/httpretrieval/multi_reader_test.go b/retrievalprovider/httpretrieval/multi_reader_test.go new file mode 100644 index 00000000..0d75c701 --- /dev/null +++ b/retrievalprovider/httpretrieval/multi_reader_test.go @@ -0,0 +1,113 @@ +package httpretrieval + +import ( + "crypto/rand" + "io" + "os" + "path/filepath" + "testing" + + "github.com/filecoin-project/go-padreader" + + "github.com/stretchr/testify/assert" +) + +func TestMultiReader(t *testing.T) { + dir := t.TempDir() + size := 10 + paddedSize := int(padreader.PaddedSize(uint64(size))) + buf := make([]byte, size) + _, err := rand.Read(buf) + assert.NoError(t, err) + f, err := os.Create(filepath.Join(dir, "test")) + assert.NoError(t, err) + _, err = f.Write(buf) + assert.NoError(t, err) + defer f.Close() // nolint + + _, err = f.Seek(0, io.SeekStart) + assert.NoError(t, err) + r := newMultiReader(f, uint64(size)) + buf2 := make([]byte, 2*size) + n, err := r.Read(buf2) + assert.NoError(t, err) + assert.Equal(t, 2*size, n) + assert.Equal(t, buf, buf2[:size]) + assert.Equal(t, make([]byte, size), buf2[size:]) + + _, err = f.Seek(0, io.SeekStart) + assert.NoError(t, err) + r = newMultiReader(f, uint64(size)) + buf2 = make([]byte, size) + n, err = r.Read(buf2) + assert.NoError(t, err) + assert.Equal(t, size, n) + assert.Equal(t, buf, buf2) + + _, err = f.Seek(0, io.SeekStart) + assert.NoError(t, err) + r = newMultiReader(f, uint64(size)) + buf2 = make([]byte, size*100) + n, err = r.Read(buf2) + assert.NoError(t, err) + assert.Equal(t, paddedSize, n) + assert.Equal(t, buf, buf2[:size]) + assert.Equal(t, make([]byte, paddedSize-size), buf2[size:paddedSize]) +} + +func TestMultiReaderSeek(t *testing.T) { + dir := t.TempDir() + size := 10 + paddedSize := int(padreader.PaddedSize(uint64(size))) + buf := make([]byte, size) + _, err := rand.Read(buf) + assert.NoError(t, err) + f, err := os.Create(filepath.Join(dir, "test")) + assert.NoError(t, err) + _, err = f.Write(buf) + assert.NoError(t, err) + defer f.Close() // nolint + + _, err = f.Seek(0, io.SeekStart) + assert.NoError(t, err) + r := newMultiReader(f, uint64(size)) + + var zero int64 + ret, err := r.Seek(zero, io.SeekStart) + assert.NoError(t, err) + assert.Equal(t, zero, ret) + + ret, err = r.Seek(zero, io.SeekEnd) + assert.NoError(t, err) + assert.Equal(t, int64(paddedSize), ret) + + ret, err = r.Seek(zero, io.SeekCurrent) + assert.NoError(t, err) + assert.Equal(t, zero, ret) + + for _, offset := range []int{1, 5, 10, 15, 50, paddedSize, 200} { + buf2 := make([]byte, size) + r = newMultiReader(f, uint64(size)) + + ret, err = r.Seek(int64(offset), io.SeekStart) + assert.NoError(t, err) + assert.Equal(t, int64(offset), ret) + + n, err := r.Read(buf2) + if offset >= paddedSize { + assert.Equal(t, io.EOF, err) + assert.Equal(t, 0, n) + continue + } + assert.NoError(t, err) + assert.Equal(t, size, n) + if offset <= size { + assert.Equal(t, buf[offset:size], buf2[:size-offset]) + assert.Equal(t, make([]byte, offset), buf2[n-offset:]) + } else { + assert.Equal(t, make([]byte, size), buf2) + } + } + + // todo: test r.Seek(zero, io.SeekCurrent) +} diff --git a/retrievalprovider/httpretrieval/server.go b/retrievalprovider/httpretrieval/server.go index 75cd51ef..82343b7a 100644 --- a/retrievalprovider/httpretrieval/server.go +++ b/retrievalprovider/httpretrieval/server.go @@ -3,15 +3,20 @@ package httpretrieval import ( "compress/gzip" "context" + "errors" "fmt" "io" "net/http" + "net/textproto" + "strconv" "strings" "time" "github.com/NYTimes/gziphandler" "github.com/filecoin-project/go-fil-markets/storagemarket" "github.com/filecoin-project/go-fil-markets/stores" + "github.com/filecoin-project/go-padreader" + "github.com/filecoin-project/go-state-types/abi" marketAPI "github.com/filecoin-project/venus/venus-shared/api/market/v1" "github.com/filecoin-project/venus/venus-shared/types" marketTypes "github.com/filecoin-project/venus/venus-shared/types/market" @@ -26,6 +31,10 @@ const ( ipfsBasePath = "/ipfs/" ) +// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of +// all of the byte-range-spec values is greater than the content size. +var errNoOverlap = errors.New("invalid range: failed to overlap") + var log = logging.Logger("httpserver") type Server struct { @@ -97,7 +106,14 @@ func (s *Server) retrievalByPieceCID(w http.ResponseWriter, r *http.Request) { } defer mountReader.Close() // nolint - serveContent(w, r, mountReader, log) + contentReader, err := handleRangeHeader(r.Header.Get("Range"), mountReader, len) + if err != nil { + log.Warnf("handleRangeHeader failed, Range: %s, error: %v", r.Header.Get("Range"), err) + badResponse(w, http.StatusInternalServerError, err) + return + } + + serveContent(w, r, contentReader, log) log.Info("end retrieval deal") } @@ -197,6 +213,102 @@ func badResponse(w http.ResponseWriter, code int, err error) { w.Write([]byte("Error: " + err.Error())) // nolint } +func handleRangeHeader(r string, mountReader io.ReadSeeker, carSize int64) (io.ReadSeeker, error) { + paddedSize := padreader.PaddedSize(uint64(carSize)) + if paddedSize == abi.UnpaddedPieceSize(carSize) { + return mountReader, nil + } + + ranges, err := parseRange(r, int64(paddedSize)) + if err != nil { + return nil, err + } + + for _, r := range ranges { + if r[0]+r[1] >= carSize { + return newMultiReader(mountReader, uint64(carSize)), nil + } + } + + return mountReader, nil +} + +// parseRange parses a Range header string as per RFC 7233. +// errNoOverlap is returned if none of the ranges overlap. +func parseRange(s string, size int64) ([][2]int64, error) { + if s == "" { + return nil, nil // header not present + } + const b = "bytes=" + if !strings.HasPrefix(s, b) { + return nil, errors.New("invalid range") + } + var ranges [][2]int64 + noOverlap := false + for _, ra := range strings.Split(s[len(b):], ",") { + ra = textproto.TrimString(ra) + if ra == "" { + continue + } + start, end, ok := strings.Cut(ra, "-") + if !ok { + return nil, errors.New("invalid range") + } + start, end = textproto.TrimString(start), textproto.TrimString(end) + r := [2]int64{} + if start == "" { + // If no start is specified, end specifies the + // range start relative to the end of the file, + // and we are dealing with + // which has to be a non-negative integer as per + // RFC 7233 Section 2.1 "Byte-Ranges". + if end == "" || end[0] == '-' { + return nil, errors.New("invalid range") + } + i, err := strconv.ParseInt(end, 10, 64) + if i < 0 || err != nil { + return nil, errors.New("invalid range") + } + if i > size { + i = size + } + r[0] = size - i + r[1] = size - r[0] + } else { + i, err := strconv.ParseInt(start, 10, 64) + if err != nil || i < 0 { + return nil, errors.New("invalid range") + } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } + r[0] = i + if end == "" { + // If no end is specified, range extends to end of the file. + r[1] = size - r[0] + } else { + i, err := strconv.ParseInt(end, 10, 64) + if err != nil || r[0] > i { + return nil, errors.New("invalid range") + } + if i >= size { + i = size - 1 + } + r[1] = i - r[0] + 1 + } + } + ranges = append(ranges, r) + } + if noOverlap && len(ranges) == 0 { + // The specified ranges did not overlap with the content. + return nil, errNoOverlap + } + return ranges, nil +} + // writeErrorWatcher calls onError if there is an error writing to the writer type writeErrorWatcher struct { http.ResponseWriter diff --git a/retrievalprovider/httpretrieval/server_test.go b/retrievalprovider/httpretrieval/server_test.go index 5eebd8d5..7455567c 100644 --- a/retrievalprovider/httpretrieval/server_test.go +++ b/retrievalprovider/httpretrieval/server_test.go @@ -14,6 +14,7 @@ import ( "time" dagstore2 "github.com/filecoin-project/dagstore" + "github.com/filecoin-project/go-padreader" "github.com/filecoin-project/venus/venus-shared/api/market/v1/mock" "github.com/filecoin-project/venus/venus-shared/types" "github.com/filecoin-project/venus/venus-shared/types/market" @@ -214,3 +215,89 @@ func TestTrustless(t *testing.T) { assert.NoError(t, resp.Body.Close()) } } + +func TestRetrievalPaddingPiece(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tmpDri := t.TempDir() + cfg := config.DefaultMarketConfig + cfg.Home.HomeDir = tmpDri + cfg.PieceStorage.Fs = []*config.FsPieceStorage{ + { + Name: "test", + ReadOnly: false, + Path: tmpDri, + }, + } + assert.NoError(t, config.SaveConfig(cfg)) + + pieceStr := "baga6ea4seaqpzcr744w2rvqhkedfqbuqrbo7xtkde2ol6e26khu3wni64nbpaeq" + piece, err := cid.Decode(pieceStr) + assert.NoError(t, err) + buf := &bytes.Buffer{} + f, err := os.Create(filepath.Join(tmpDri, pieceStr)) + assert.NoError(t, err) + for i := 0; i < 100; i++ { + buf.WriteString("TEST TEST\n") + } + _, err = f.Write(buf.Bytes()) + assert.NoError(t, err) + assert.NoError(t, f.Close()) + + pieceStorage, err := piecestorage.NewPieceStorageManager(&cfg.PieceStorage) + assert.NoError(t, err) + ctrl := gomock.NewController(t) + m := mock.NewMockIMarket(ctrl) + m.EXPECT().MarketListIncompleteDeals(gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, p *market.StorageDealQueryParams) ([]market.MinerDeal, error) { + if p.PieceCID != pieceStr { + return nil, fmt.Errorf("not found deal") + } + return append([]market.MinerDeal{}, market.MinerDeal{ClientDealProposal: types.ClientDealProposal{Proposal: types.DealProposal{PieceCID: piece}}}), nil + }).AnyTimes() + + s, err := NewServer(ctx, pieceStorage, m, nil) + assert.NoError(t, err) + port := "34897" + startHTTPServer(ctx, t, port, s) + + carSize := len(buf.Bytes()) + paddedSize := padreader.PaddedSize(uint64(carSize)) + + cases := []struct { + r string + expect []byte + }{ + { + r: fmt.Sprintf("%d-%d", 0, 99), + expect: buf.Bytes()[0:100], + }, + { + r: fmt.Sprintf("%d-%d", 0, carSize-1), + expect: buf.Bytes(), + }, + { + r: fmt.Sprintf("%d-%d", 0, carSize+10), + expect: append(buf.Bytes(), make([]byte, 11)...), + }, + { + r: "0-", + expect: append(buf.Bytes(), make([]byte, int(paddedSize)-carSize)...), + }, + } + + for _, c := range cases { + url := fmt.Sprintf("http://127.0.0.1:%s/piece/%s", port, pieceStr) + req, err := http.NewRequest(http.MethodGet, url, nil) + assert.NoError(t, err) + req.Header.Set("Range", fmt.Sprintf("bytes=%s", c.r)) + resp, err := http.DefaultClient.Do(req) + assert.NoError(t, err) + defer resp.Body.Close() // nolint + + data, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, c.expect, data) + } +}