diff --git a/src/net/http/fs.go b/src/net/http/fs.go index af7511a7a4bd7e..287980df3d876e 100644 --- a/src/net/http/fs.go +++ b/src/net/http/fs.go @@ -374,8 +374,13 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } w.WriteHeader(code) - if r.Method != "HEAD" { - io.CopyN(w, sendContent, sendSize) + if r.Method != MethodHead { + if sendSize == size { + // use Copy in the non-range case to make use of WriterTo if available + io.Copy(w, sendContent) + } else { + io.CopyN(w, sendContent, sendSize) + } } } diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go index 861e70caf23963..c8778458e7ddbd 100644 --- a/src/net/http/fs_test.go +++ b/src/net/http/fs_test.go @@ -928,6 +928,7 @@ func testServeContent(t *testing.T, mode testMode) { wantContentType string wantContentRange string wantStatus int + wantContent []byte } htmlModTime := mustStat(t, "testdata/index.html").ModTime() tests := map[string]testCase{ @@ -1143,6 +1144,24 @@ func testServeContent(t *testing.T, mode testMode) { wantStatus: 412, wantLastMod: htmlModTime.UTC().Format(TimeFormat), }, + "uses_writeTo_if_available_and_non-range": { + content: &panicOnNonWriterTo{seekWriterTo: strings.NewReader("foobar")}, + serveContentType: "text/plain; charset=utf-8", + wantContentType: "text/plain; charset=utf-8", + wantStatus: StatusOK, + wantContent: []byte("foobar"), + }, + "do_not_use_writeTo_for_range_requests": { + content: &panicOnWriterTo{ReadSeeker: strings.NewReader("foobar")}, + serveContentType: "text/plain; charset=utf-8", + reqHeader: map[string]string{ + "Range": "bytes=0-4", + }, + wantContentType: "text/plain; charset=utf-8", + wantContentRange: "bytes 0-4/6", + wantStatus: StatusPartialContent, + wantContent: []byte("fooba"), + }, } for testName, tt := range tests { var content io.ReadSeeker @@ -1156,7 +1175,8 @@ func testServeContent(t *testing.T, mode testMode) { } else { content = tt.content } - for _, method := range []string{"GET", "HEAD"} { + contentOut := &strings.Builder{} + for _, method := range []string{MethodGet, MethodHead} { //restore content in case it is consumed by previous method if content, ok := content.(*strings.Reader); ok { content.Seek(0, io.SeekStart) @@ -1182,7 +1202,8 @@ func testServeContent(t *testing.T, mode testMode) { if err != nil { t.Fatal(err) } - io.Copy(io.Discard, res.Body) + contentOut.Reset() + io.Copy(contentOut, res.Body) res.Body.Close() if res.StatusCode != tt.wantStatus { t.Errorf("test %q using %q: got status = %d; want %d", testName, method, res.StatusCode, tt.wantStatus) @@ -1196,10 +1217,28 @@ func testServeContent(t *testing.T, mode testMode) { if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { t.Errorf("test %q using %q: got last-modified = %q, want %q", testName, method, g, e) } + if g, e := contentOut.String(), tt.wantContent; e != nil && method == MethodGet && g != string(e) { + t.Errorf("test %q using %q: got unexpected content %q, want %q", testName, method, g, e) + } } } } +type seekWriterTo interface { + io.Seeker + io.WriterTo +} + +type panicOnNonWriterTo struct { + io.Reader + seekWriterTo +} + +type panicOnWriterTo struct { + io.ReadSeeker + io.WriterTo +} + // Issue 12991 func TestServerFileStatError(t *testing.T) { rec := httptest.NewRecorder() diff --git a/src/net/sendfile_linux.go b/src/net/sendfile_linux.go index 9a7d0058032f13..2791826e7a04e9 100644 --- a/src/net/sendfile_linux.go +++ b/src/net/sendfile_linux.go @@ -7,7 +7,7 @@ package net import ( "internal/poll" "io" - "os" + "syscall" ) // sendFile copies the contents of r to c using the sendfile @@ -27,7 +27,10 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, nil, true } } - f, ok := r.(*os.File) + f, ok := r.(interface { + Fd() uintptr // not used, but limits the type to *os.File + SyscallConn() (syscall.RawConn, error) + }) if !ok { return 0, nil, false } diff --git a/src/net/sendfile_unix_alt.go b/src/net/sendfile_unix_alt.go index 5cb65ee7670c49..85518390a4170a 100644 --- a/src/net/sendfile_unix_alt.go +++ b/src/net/sendfile_unix_alt.go @@ -9,7 +9,8 @@ package net import ( "internal/poll" "io" - "os" + "io/fs" + "syscall" ) // sendFile copies the contents of r to c using the sendfile @@ -34,7 +35,12 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) { return 0, nil, true } } - f, ok := r.(*os.File) + f, ok := r.(interface { + io.Seeker + Fd() uintptr // not used, but limits the type to *os.File + Stat() (fs.FileInfo, error) + SyscallConn() (syscall.RawConn, error) + }) if !ok { return 0, nil, false } diff --git a/src/net/sendfile_windows.go b/src/net/sendfile_windows.go index 59b1b0d5c1dd85..8f79f14d8b1f3a 100644 --- a/src/net/sendfile_windows.go +++ b/src/net/sendfile_windows.go @@ -7,7 +7,6 @@ package net import ( "internal/poll" "io" - "os" "syscall" ) @@ -29,7 +28,9 @@ func sendFile(fd *netFD, r io.Reader) (written int64, err error, handled bool) { } } - f, ok := r.(*os.File) + f, ok := r.(interface { + Fd() uintptr + }) if !ok { return 0, nil, false }