Skip to content

net/http: use Copy in ServeContent if CopyN not needed #65106

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/net/http/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,8 +374,13 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time,
}
w.WriteHeader(code)

if r.Method != "HEAD" {
io.CopyN(w, sendContent, sendSize)
if r.Method != MethodHead {
if sendSize == size {
// use Copy in the non-range case to make use of WriterTo if available
io.Copy(w, sendContent)
} else {
io.CopyN(w, sendContent, sendSize)
}
}
}

Expand Down
43 changes: 41 additions & 2 deletions src/net/http/fs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,7 @@ func testServeContent(t *testing.T, mode testMode) {
wantContentType string
wantContentRange string
wantStatus int
wantContent []byte
}
htmlModTime := mustStat(t, "testdata/index.html").ModTime()
tests := map[string]testCase{
Expand Down Expand Up @@ -1143,6 +1144,24 @@ func testServeContent(t *testing.T, mode testMode) {
wantStatus: 412,
wantLastMod: htmlModTime.UTC().Format(TimeFormat),
},
"uses_writeTo_if_available_and_non-range": {
content: &panicOnNonWriterTo{seekWriterTo: strings.NewReader("foobar")},
serveContentType: "text/plain; charset=utf-8",
wantContentType: "text/plain; charset=utf-8",
wantStatus: StatusOK,
wantContent: []byte("foobar"),
},
"do_not_use_writeTo_for_range_requests": {
content: &panicOnWriterTo{ReadSeeker: strings.NewReader("foobar")},
serveContentType: "text/plain; charset=utf-8",
reqHeader: map[string]string{
"Range": "bytes=0-4",
},
wantContentType: "text/plain; charset=utf-8",
wantContentRange: "bytes 0-4/6",
wantStatus: StatusPartialContent,
wantContent: []byte("fooba"),
},
}
for testName, tt := range tests {
var content io.ReadSeeker
Expand All @@ -1156,7 +1175,8 @@ func testServeContent(t *testing.T, mode testMode) {
} else {
content = tt.content
}
for _, method := range []string{"GET", "HEAD"} {
contentOut := &strings.Builder{}
for _, method := range []string{MethodGet, MethodHead} {
//restore content in case it is consumed by previous method
if content, ok := content.(*strings.Reader); ok {
content.Seek(0, io.SeekStart)
Expand All @@ -1182,7 +1202,8 @@ func testServeContent(t *testing.T, mode testMode) {
if err != nil {
t.Fatal(err)
}
io.Copy(io.Discard, res.Body)
contentOut.Reset()
io.Copy(contentOut, res.Body)
res.Body.Close()
if res.StatusCode != tt.wantStatus {
t.Errorf("test %q using %q: got status = %d; want %d", testName, method, res.StatusCode, tt.wantStatus)
Expand All @@ -1196,10 +1217,28 @@ func testServeContent(t *testing.T, mode testMode) {
if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e {
t.Errorf("test %q using %q: got last-modified = %q, want %q", testName, method, g, e)
}
if g, e := contentOut.String(), tt.wantContent; e != nil && method == MethodGet && g != string(e) {
t.Errorf("test %q using %q: got unexpected content %q, want %q", testName, method, g, e)
}
}
}
}

type seekWriterTo interface {
io.Seeker
io.WriterTo
}

type panicOnNonWriterTo struct {
io.Reader
seekWriterTo
}

type panicOnWriterTo struct {
io.ReadSeeker
io.WriterTo
}

// Issue 12991
func TestServerFileStatError(t *testing.T) {
rec := httptest.NewRecorder()
Expand Down
7 changes: 5 additions & 2 deletions src/net/sendfile_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ package net
import (
"internal/poll"
"io"
"os"
"syscall"
)

// sendFile copies the contents of r to c using the sendfile
Expand All @@ -27,7 +27,10 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
return 0, nil, true
}
}
f, ok := r.(*os.File)
f, ok := r.(interface {
Fd() uintptr // not used, but limits the type to *os.File
SyscallConn() (syscall.RawConn, error)
})
if !ok {
return 0, nil, false
}
Expand Down
10 changes: 8 additions & 2 deletions src/net/sendfile_unix_alt.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ package net
import (
"internal/poll"
"io"
"os"
"io/fs"
"syscall"
)

// sendFile copies the contents of r to c using the sendfile
Expand All @@ -34,7 +35,12 @@ func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
return 0, nil, true
}
}
f, ok := r.(*os.File)
f, ok := r.(interface {
io.Seeker
Fd() uintptr // not used, but limits the type to *os.File
Stat() (fs.FileInfo, error)
SyscallConn() (syscall.RawConn, error)
})
if !ok {
return 0, nil, false
}
Expand Down
5 changes: 3 additions & 2 deletions src/net/sendfile_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ package net
import (
"internal/poll"
"io"
"os"
"syscall"
)

Expand All @@ -29,7 +28,9 @@ func sendFile(fd *netFD, r io.Reader) (written int64, err error, handled bool) {
}
}

f, ok := r.(*os.File)
f, ok := r.(interface {
Fd() uintptr
})
if !ok {
return 0, nil, false
}
Expand Down