From d3e979a431a95557057ed528cfc9ed3fd6595781 Mon Sep 17 00:00:00 2001 From: jhwz <52683873+jhwz@users.noreply.github.com> Date: Tue, 22 Mar 2022 08:03:40 +1300 Subject: [PATCH] Return underlying Seeker from Identify if possible (#327) * [feat] don't require io.Seeker for identify * tidy up * Refactor and simplify with some bug fixes * Clarify returned Reader in godoc comment * if underlying reader supports seek use that * update comment * Update formats.go Co-authored-by: Matt Holt Co-authored-by: Matthew Holt --- formats.go | 7 +++++++ formats_test.go | 23 +++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/formats.go b/formats.go index 99ce196f..a63a9691 100644 --- a/formats.go +++ b/formats.go @@ -319,7 +319,14 @@ func (rr *rewindReader) rewind() { // bytes, then from the underlying stream. After calling this, // no more rewinding is allowed since reads from the stream are // not recorded, so rewinding properly is impossible. +// If the underlying reader implements io.Seeker, then the +// underlying reader will be used directly. func (rr *rewindReader) reader() io.Reader { + if ras, ok := rr.Reader.(io.Seeker); ok { + if _, err := ras.Seek(0, io.SeekStart); err == nil { + return rr.Reader + } + } return io.MultiReader(bytes.NewReader(rr.buf.Bytes()), rr.Reader) } diff --git a/formats_test.go b/formats_test.go index a1d0eb74..2531ceb1 100644 --- a/formats_test.go +++ b/formats_test.go @@ -387,3 +387,26 @@ func TestIdentifyFindFormatByStreamContent(t *testing.T) { }) } } + +func TestIdentifyAndOpenZip(t *testing.T) { + f, err := os.Open("testdata/test.zip") + checkErr(t, err, "opening zip") + defer f.Close() + + format, reader, err := Identify("test.zip", f) + checkErr(t, err, "identifying zip") + if format.Name() != ".zip" { + t.Fatalf("unexpected format found: expected=.zip actual:%s", format.Name()) + } + + err = format.(Extractor).Extract(context.Background(), reader, nil, func(ctx context.Context, f File) error { + rc, err := f.Open() + if err != nil { + return err + } + defer rc.Close() + _, err = io.ReadAll(rc) + return err + }) + checkErr(t, err, "extracting zip") +}